
I have this simple code:

from sklearn import tree
import matplotlib.pyplot as plt

clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, y)

tree.plot_tree(clf)
plt.show()

And the result I get is this graph: [image of the resulting tree plot]

How do I make this graph legible? I'm using PyCharm Professional 2019.3 as my IDE.

Artur

4 Answers


I think the setting you are looking for is fontsize. You have to balance it with max_depth and figsize to get a readable plot. Here is an example:

from sklearn import tree
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt

# load data
X, y = load_iris(return_X_y=True)

# create and train model
clf = tree.DecisionTreeClassifier(max_depth=4)  # set hyperparameter
clf.fit(X, y)

# plot tree
plt.figure(figsize=(12,12))  # set plot size (denoted in inches)
tree.plot_tree(clf, fontsize=10)
plt.show()

[image: the resulting tree plot]
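Limiting max_depth on the classifier changes the model itself. If you want to keep the full tree but only draw its top levels, plot_tree also accepts its own max_depth argument, which truncates the drawing and leaves the fitted model untouched. A minimal sketch on the same iris data; the depth of 3 and the figure size are arbitrary choices to tune for your tree:

```python
from sklearn import tree
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt

X, y = load_iris(return_X_y=True)

# Train a full-depth tree; the model itself is not pruned
clf = tree.DecisionTreeClassifier(random_state=0)
clf.fit(X, y)

# Draw only the top 3 levels; deeper subtrees are elided in the drawing
fig, ax = plt.subplots(figsize=(12, 8))
annotations = tree.plot_tree(clf, max_depth=3, fontsize=10, filled=True, ax=ax)
plt.show()
```

Because fewer nodes share the same figure area, each box gets more room, so a larger fontsize stays readable.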

If you want to capture the structure of the whole tree, I think saving the plot with a small font and a high dpi is the solution. You can then open the image and zoom in on specific nodes to inspect them.

# create and train model (no depth limit)
clf = tree.DecisionTreeClassifier()
clf.fit(X, y)

# save plot: small font, high dpi (matplotlib's default figure dpi is 100)
plt.figure(figsize=(12,12))
tree.plot_tree(clf, fontsize=6)
plt.savefig('tree_high_dpi.png', dpi=300)

Here is an example of how it looks on a bigger tree:

[images: the bigger tree saved with a small font and high dpi]
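As a variation on the high-dpi approach, you could save the figure in a vector format such as SVG, which an image viewer or browser can zoom without the text becoming pixelated; the dpi setting then matters mostly for embedded raster content. A sketch, assuming the iris data from above and using "tree.svg" as a placeholder filename:

```python
from sklearn import tree
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt

X, y = load_iris(return_X_y=True)
clf = tree.DecisionTreeClassifier(random_state=0).fit(X, y)

fig, ax = plt.subplots(figsize=(12, 12))
tree.plot_tree(clf, fontsize=6, ax=ax)

# The .svg extension makes matplotlib write vector output, so node labels
# stay sharp at any zoom level
fig.savefig("tree.svg", bbox_inches="tight")
plt.close(fig)  # free the figure since it is not shown interactively
```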

fpersyn

What about setting the size of the figure beforehand:

clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, y)

fig, ax = plt.subplots(figsize=(10, 10))  # whatever size you want
tree.plot_tree(clf, ax=ax)  # clf is already fitted; no need to refit here
plt.show()
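A fixed figsize works for small trees but may still be too cramped for deep ones. One option is to derive the size from the fitted tree: width from the number of leaves (get_n_leaves()) and height from the depth (get_depth()). In the sketch below, the helper tree_figsize and its inches-per-leaf and inches-per-level factors are hypothetical guesses to tune, not part of scikit-learn:

```python
from sklearn import tree
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt

X, y = load_iris(return_X_y=True)
clf = tree.DecisionTreeClassifier(random_state=0).fit(X, y)

def tree_figsize(model, inches_per_leaf=1.5, inches_per_level=1.5, min_size=6.0):
    """Hypothetical helper: scale figure width with the number of leaves and
    height with the depth of a fitted tree, never going below min_size."""
    width = max(min_size, model.get_n_leaves() * inches_per_leaf)
    height = max(min_size, (model.get_depth() + 1) * inches_per_level)
    return width, height

fig, ax = plt.subplots(figsize=tree_figsize(clf))
tree.plot_tree(clf, fontsize=10, ax=ax)
plt.show()
```

Since the leaves are spread horizontally, tying width to the leaf count keeps the spacing between sibling boxes roughly constant as the tree grows.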
glemaitre
  • This does nothing to actually make the plot_tree output legible, which is what the OP wanted. All it does is enlarge the subplot so that it fits more items; it does not grow the subplot to the point where anything becomes readable, as a dynamic approach would. – Vaidøtas I. Jan 20 '20 at 13:19

Try this:

plt.figure(figsize=(12,12))
tree.plot_tree(clf, fontsize=10)
plt.show()

The problem is solved if you set the size beforehand:

import matplotlib.pyplot as plt
from sklearn.tree import plot_tree, export_text

fig = plt.figure(figsize=(25,20))
_ = plot_tree(clf)
plt.show()
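If the goal is only to read the split rules rather than to see a diagram, scikit-learn's export_text function prints the tree as indented plain text, which stays legible regardless of tree size. A minimal sketch using the iris data as a stand-in for your own X and y; the max_depth=3 on the classifier is just to keep the output short:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(iris.data, iris.target)

# export_text returns the decision rules as an indented string
report = export_text(clf, feature_names=list(iris.feature_names))
print(report)
```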
drGabriel