Is there a way to print a trained decision tree in scikit-learn? I want to train a decision tree for my thesis and I want to put the picture of the tree in the thesis. Is that possible?
3 Answers
There is a method to export to Graphviz format, `sklearn.tree.export_graphviz`: http://scikit-learn.org/stable/modules/generated/sklearn.tree.export_graphviz.html
So from the online docs:
>>> from sklearn.datasets import load_iris
>>> from sklearn import tree
>>>
>>> clf = tree.DecisionTreeClassifier()
>>> iris = load_iris()
>>>
>>> clf = clf.fit(iris.data, iris.target)
>>> tree.export_graphviz(clf,
... out_file='tree.dot')
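For a figure that will go into a thesis, it may help to label the nodes with the real feature and class names. Below is a minimal sketch, assuming a scikit-learn version where `export_graphviz` returns the DOT source as a string when `out_file=None`. The `random_state=0` and the styling options (`filled`, `rounded`) are illustrative choices, not requirements:

```python
from sklearn.datasets import load_iris
from sklearn import tree

iris = load_iris()
# random_state fixes tie-breaking between equally good splits, so reruns match
clf = tree.DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

# With out_file=None, export_graphviz returns the DOT source instead of writing it
dot_source = tree.export_graphviz(
    clf,
    out_file=None,
    feature_names=iris.feature_names,      # e.g. "petal length (cm)" instead of X[2]
    class_names=list(iris.target_names),   # "setosa", "versicolor", "virginica"
    filled=True,                           # colour nodes by majority class
    rounded=True,
)

# Save the DOT text so it can be rendered with the Graphviz command-line tools
with open("tree.dot", "w") as f:
    f.write(dot_source)
```

The saved file can then be rendered with the Graphviz CLI, for example `dot -Tpdf tree.dot -o tree.pdf`.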
Then you can render this file with Graphviz, or, if you have pydot installed, you can do this more directly: http://scikit-learn.org/stable/modules/tree.html
>>> from io import StringIO
>>> import pydot
>>> dot_data = StringIO()
>>> tree.export_graphviz(clf, out_file=dot_data)
>>> graph = pydot.graph_from_dot_data(dot_data.getvalue())
>>> graph.write_pdf("iris.pdf")
This writes the tree to iris.pdf. I can't display the result here, so you'll have to follow the link to see the rendered example (an SVG in the docs): http://scikit-learn.org/stable/_images/iris.svg
Update
It seems that the behaviour has changed since I first answered this question: `pydot.graph_from_dot_data` now returns a list, and hence you get this error:
AttributeError: 'list' object has no attribute 'write_pdf'
When you see an error like this, it's worth printing and inspecting the object first; most likely what you want is the first element of the list:
graph[0].write_pdf("iris.pdf")
Thanks to @NickBraunagel for the comment

- I get this error: `AttributeError: 'list' object has no attribute 'write_pdf'`. How can I resolve this? – Ernest Soo Jun 01 '17 at 15:01
- @EdChum can you kindly check this https://stackoverflow.com/questions/48880557/how-to-print-and-visualize-the-average-randomforestclassifier-of-all-trees-for-a – Feb 20 '18 at 09:27
- @ErnestSoo (and anyone else running into your error): `pydot.graph_from_dot_data()` returns the desired `graph` (the `pydot.Dot` object), but it returns it within a `list`, so access the list's first element to get the `pydot.Dot` object: `graph[0].write_pdf("iris.pdf")` – NickBraunagel Mar 11 '18 at 22:02
- @NickBraunagel as it seems a lot of people are getting this error, I will add this as an update. It looks like this is a change in behaviour since I answered this question over 3 years ago, thanks – EdChum Mar 11 '18 at 22:21
- How would you do the same thing but on test data? – bernando_vialli Sep 17 '18 at 17:04
Although I'm late to the game, the below comprehensive instructions could be useful for others who want to display decision tree output:
Install the necessary modules:
- Install `graphviz`. I used conda's install package here (recommended over `pip install graphviz`, as the pip install doesn't include the actual GraphViz executables).
- Install `pydot` via pip (`pip install pydot`).
- Add the graphviz folder containing the .exe files (e.g. dot.exe) to your PATH environment variable.
- Run EdChum's code above (NOTE: `graph` is a `list` containing the `pydot.Dot` object):
from sklearn.datasets import load_iris
from sklearn import tree
from io import StringIO
import pydot
clf = tree.DecisionTreeClassifier()
iris = load_iris()
clf = clf.fit(iris.data, iris.target)
dot_data = StringIO()
tree.export_graphviz(clf, out_file=dot_data)
graph = pydot.graph_from_dot_data(dot_data.getvalue())
graph[0].write_pdf("iris.pdf") # must access graph's first element
Now you'll find "iris.pdf" in your current working directory.
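`write_pdf` works by running the Graphviz `dot` program, so if the PATH step didn't take effect, the call fails even though pydot itself imports fine. A quick way to check is the standard library's `shutil.which`; this is a minimal sketch, and `graphviz_on_path` is a hypothetical helper name:

```python
import shutil


def graphviz_on_path() -> bool:
    """Hypothetical helper: True if Graphviz's `dot` executable is found on PATH."""
    # On Windows, shutil.which also tries PATHEXT extensions, so "dot" matches dot.exe
    return shutil.which("dot") is not None


if __name__ == "__main__":
    if graphviz_on_path():
        print("Graphviz found at", shutil.which("dot"))
    else:
        print("`dot` not found: add the Graphviz bin folder to PATH and restart the shell")
```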

There are 4 methods which I'm aware of for plotting the scikit-learn decision tree:
- print the text representation of the tree with the `sklearn.tree.export_text` method
- plot with the `sklearn.tree.plot_tree` method (`matplotlib` needed)
- plot with the `sklearn.tree.export_graphviz` method (`graphviz` needed)
- plot with the `dtreeviz` package (`dtreeviz` and `graphviz` needed)
The simplest option is to export to the text representation. The example decision tree looks like this:
|--- feature_2 <= 2.45
| |--- class: 0
|--- feature_2 > 2.45
| |--- feature_3 <= 1.75
| | |--- feature_2 <= 4.95
| | | |--- feature_3 <= 1.65
| | | | |--- class: 1
| | | |--- feature_3 > 1.65
| | | | |--- class: 2
| | |--- feature_2 > 4.95
| | | |--- feature_3 <= 1.55
| | | | |--- class: 2
| | | |--- feature_3 > 1.55
| | | | |--- feature_0 <= 6.95
| | | | | |--- class: 1
| | | | |--- feature_0 > 6.95
| | | | | |--- class: 2
| |--- feature_3 > 1.75
| | |--- feature_2 <= 4.85
| | | |--- feature_1 <= 3.10
| | | | |--- class: 2
| | | |--- feature_1 > 3.10
| | | | |--- class: 1
| | |--- feature_2 > 4.85
| | | |--- class: 2
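The output above uses generic names such as `feature_2`. A minimal sketch of producing the same kind of report with the real iris feature names, assuming scikit-learn 0.21 or later (where `export_text` is available); `random_state=0` is an illustrative choice:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
clf = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

# feature_names replaces the generic feature_0 ... feature_3 labels
report = export_text(clf, feature_names=list(iris.feature_names))
print(report)
```

Because the result is a plain string, it can also be written to a file and included verbatim in a document.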
Then, if you have `matplotlib` installed, you can plot with `sklearn.tree.plot_tree`:
tree.plot_tree(clf) # the clf is your decision tree model
The example output is similar to what you will get with `export_graphviz` (image not reproduced here).
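For a thesis figure, the plot can be saved straight to a vector file. A minimal sketch, assuming scikit-learn 0.21 or later and matplotlib; the figure size, the `filled=True` styling and the file name `iris_tree.pdf` are illustrative choices:

```python
import matplotlib

matplotlib.use("Agg")  # render without a display, e.g. on a server or in CI
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

iris = load_iris()
clf = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

# A large canvas keeps the node text legible for deeper trees
fig, ax = plt.subplots(figsize=(12, 8))
plot_tree(
    clf,
    feature_names=list(iris.feature_names),
    class_names=list(iris.target_names),
    filled=True,  # colour nodes by majority class
    ax=ax,
)
# PDF is a vector format, so the tree stays sharp when scaled in the thesis
fig.savefig("iris_tree.pdf", bbox_inches="tight")
plt.close(fig)
```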
You can also try the `dtreeviz` package, which gives you much more information, such as the distribution of the training data at each node (example image not reproduced here).
You can find a comparison of different visualizations of sklearn decision trees, with code snippets, in this blog post: link.
