I have a very similar example that I'm working through. It's based on an ML how-to book that uses the Taiwan credit card dataset to predict default risk. My setup is as follows:
```python
from io import StringIO  # on Python 3, six.StringIO is just an alias for io.StringIO
from sklearn.tree import export_graphviz
from IPython.display import Image
import pydotplus
```
The decision tree plot is then created like this:
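```python
dot_data = StringIO()  # in-memory text buffer that receives the DOT source
export_graphviz(decision_tree=class_tree,
                out_file=dot_data,
                filled=True,
                rounded=True,
                feature_names=X_train.columns,
                class_names=['pay', 'default'],
                special_characters=True)
# Parse the DOT text and render it to PNG bytes (needs the Graphviz binaries installed)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
Image(graph.create_png())
```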
I think it all comes from the `out_file=dot_data` argument, but I can't figure out where the file path is created and stored, since `print(dot_data.getvalue())` did not show any path name.
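There is no file path to find: `StringIO()` is an in-memory text buffer, so `export_graphviz` writes the DOT source into memory and `pydotplus.graph_from_dot_data(dot_data.getvalue())` reads it back from there. The sketch below uses the iris dataset and a small `DecisionTreeClassifier` as stand-ins for `class_tree` and `X_train` (an assumption, not the book's credit-card data) to compare three `out_file` choices: a `StringIO` buffer, `None` (the DOT text is returned as a string), and a filename (the DOT text is written to that path).

```python
from io import StringIO
import os
import tempfile

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# Stand-in model (assumption): iris instead of the credit-card data.
iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

# 1) StringIO buffer: the DOT text lives only in memory, so there is no path.
buffer = StringIO()
export_graphviz(clf, out_file=buffer, filled=True, rounded=True)
dot_from_buffer = buffer.getvalue()

# 2) out_file=None: export_graphviz returns the DOT text as a string instead.
dot_from_return = export_graphviz(clf, out_file=None, filled=True, rounded=True)

# 3) A filename: the DOT text is written to disk at that path.
path = os.path.join(tempfile.mkdtemp(), "tree.dot")
export_graphviz(clf, out_file=path, filled=True, rounded=True)
with open(path) as f:
    dot_from_file = f.read()

print(dot_from_buffer.splitlines()[0])  # digraph Tree {
print(dot_from_buffer == dot_from_return == dot_from_file)  # True
```

All three hold the same DOT source; only the third one ever touches the file system.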
In my research I came across `sklearn.tree.plot_tree()`, which draws the same tree with matplotlib, so neither Graphviz nor pydotplus is needed. I took the `export_graphviz` arguments above and, wherever `plot_tree` had a matching argument, I added it.
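Matching the arguments by hand can also be checked programmatically. A minimal sketch, assuming only the standard-library `inspect` module: it compares the parameter names of `export_graphviz` and `plot_tree`. The exact lists vary between scikit-learn versions, so no output is shown here.

```python
import inspect

from sklearn.tree import export_graphviz, plot_tree

# Parameter names accepted by each function in the installed scikit-learn.
graphviz_params = set(inspect.signature(export_graphviz).parameters)
plot_params = set(inspect.signature(plot_tree).parameters)

shared = sorted(graphviz_params & plot_params)         # can be copied across as-is
graphviz_only = sorted(graphviz_params - plot_params)  # must be dropped for plot_tree
plot_only = sorted(plot_params - graphviz_params)      # new options, e.g. ax, fontsize

print("shared:        ", shared)
print("graphviz only: ", graphviz_only)
print("plot_tree only:", plot_only)
```

Arguments such as `out_file` and `special_characters` exist only on the Graphviz side, while `ax` and `fontsize` exist only on `plot_tree`.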
I ended up with the following, which produced the same image as the one in the book:
```python
import matplotlib.pyplot as plt
from sklearn import tree

plt.figure(figsize=(20, 10))
tree.plot_tree(class_tree,
               filled=True, rounded=True,
               feature_names=X_train.columns,
               class_names=['pay', 'default'],
               fontsize=12)
plt.show()
```
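Because `plot_tree` draws with matplotlib, the figure can also be saved straight to disk with `savefig`, which gives an actual file path. A self-contained sketch with the iris dataset standing in for the credit-card data (assumption); the filename `decision_tree.png` is hypothetical. The `Agg` backend line is only there so the script runs without a display; in a notebook, drop that line and the `plt.close(fig)` call so the figure still shows inline.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: render to files, no window needed
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

# Stand-in data and model (assumption): iris instead of the credit-card data.
iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(iris.data, iris.target)

# Draw on an explicit Axes so the same figure object can be saved afterwards.
fig, ax = plt.subplots(figsize=(20, 10))
annotations = plot_tree(
    clf,
    ax=ax,
    filled=True,
    rounded=True,
    feature_names=iris.feature_names,
    class_names=list(iris.target_names),
    fontsize=12,
)

# Write the rendered tree to a PNG file in the current working directory.
fig.savefig("decision_tree.png", dpi=100, bbox_inches="tight")
plt.close(fig)
```

`plot_tree` returns the list of text annotations it drew, one per node, which is handy if you want to restyle individual boxes afterwards.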