I'm working on the Kaggle Titanic data set and trying to understand decision trees better. I've worked with linear regressions a good bit but never with decision trees. I'm trying to create a visualization of my tree in Python, but something isn't working. My code is below.

import pandas as pd
from sklearn import tree
from sklearn.datasets import load_iris
import numpy as np


train_file='.......\RUN.csv'
train=pd.read_csv(train_file)

# Encode categorical columns as numbers and impute missing values.
# Series.map avoids the chained assignment train[col][mask] = value,
# which pandas may apply to a temporary copy.
train["Sex"] = train["Sex"].map({"male": 0, "female": 1})
train["Embarked"] = train["Embarked"].fillna("S")
train["Embarked"] = train["Embarked"].map({"S": 0, "C": 1, "Q": 2})
train["Age"] = train["Age"].fillna(train["Age"].median())
train["Pclass"] = train["Pclass"].fillna(train["Pclass"].median())
train["Fare"] = train["Fare"].fillna(train["Fare"].median())

target = train["Survived"].values
features_one = train[["Pclass", "Sex", "Age", "Fare","SibSp","Parch","Embarked"]].values


# Fit your first decision tree: my_tree_one
my_tree_one = tree.DecisionTreeClassifier(max_depth = 10, min_samples_split = 5, random_state = 1)

iris=load_iris()

my_tree_one = my_tree_one.fit(features_one, target)

tree.export_graphviz(my_tree_one, out_file='tree.dot')

How do I actually see the decision tree? Trying to visualize it.

Help appreciated!

cchamberlain
Josh Dautel

3 Answers

2

Did you check http://scikit-learn.org/stable/modules/tree.html? It describes how to plot the tree as a PNG image:

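 from IPython.display import Image
 import pydotplus
 # out_file=None makes export_graphviz return the DOT source as a string;
 # with a file name it writes the file and returns None
 dot_data = tree.export_graphviz(my_tree_one, out_file=None)
 graph = pydotplus.graph_from_dot_data(dot_data)
 Image(graph.create_png())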
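
If installing pydotplus or the Graphviz binaries is the obstacle, a possible alternative is sklearn.tree.plot_tree, which scikit-learn 0.21 and later provide and which draws the tree with matplotlib only. The following is a minimal sketch, not the approach from the linked page: the iris data stands in for the Titanic features, and the output name tree.png is an arbitrary choice.

```python
from pathlib import Path

import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so this also runs without a display
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.datasets import load_iris

# Stand-in data (assumption): iris replaces the Titanic feature matrix here
iris = load_iris()
clf = tree.DecisionTreeClassifier(max_depth=3, random_state=1)
clf.fit(iris.data, iris.target)

# Draw the fitted tree onto a matplotlib axis and save it as an image
fig, ax = plt.subplots(figsize=(12, 6))
annotations = tree.plot_tree(
    clf,
    feature_names=list(iris.feature_names),
    class_names=list(iris.target_names),
    filled=True,
    ax=ax,
)
out_path = Path("tree.png")  # hypothetical output file name
fig.savefig(out_path, dpi=150)
```

In a notebook, the figure is shown inline and the savefig call can be dropped.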
RMS
  • >>> import os >>> os.unlink('iris.dot') – Josh Dautel Dec 15 '16 at 15:50
  • It says to do this ^. However, that just deletes the file. Any ideas? I also don't have pydotplus. I tried downloading it using pip, but it didn't work. – Josh Dautel Dec 15 '16 at 15:50
  • I think the problem is Graphviz, and you should download it: http://www.graphviz.org/Download..php http://stackoverflow.com/questions/18438997/why-is-pydot-unable-to-find-graphvizs-executables-in-windows-8. First install Graphviz, then pydot. Or use Linux. I will get back to it slightly later. – RMS Dec 16 '16 at 09:54
0

From Wikipedia:

The DOT language defines a graph, but does not provide facilities for rendering the graph. There are several programs that can be used to render, view, and manipulate graphs in the DOT language:

Graphviz - A collection of libraries and utilities to manipulate and render graphs

Canviz - a JavaScript library for rendering dot files.

Viz.js - A simple Graphviz JavaScript client

Grappa - A partial port of Graphviz to Java.

Beluging - A Python & Google Cloud based viewer of DOT and Beluga extensions.

Tulip can import dot files for analysis

OmniGraffle can import a subset of DOT, producing an editable document. (The result cannot be exported back to DOT, however.)

ZGRViewer, a GraphViz/DOT Viewer

VizierFX, A Flex graph rendering library

Gephi - an interactive visualization and exploration platform for all kinds of networks and complex systems, dynamic and hierarchical graphs

So any one of these programs would be capable of visualizing your tree.
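
With Graphviz itself, the usual route is its dot command-line tool, for example dot -Tpng tree.dot -o tree.png. A small sketch of calling it from Python follows; the helper names dot_command and render_dot are hypothetical, and the code assumes the dot executable is on the PATH.

```python
import shutil
import subprocess

def dot_command(dot_path, out_path, fmt="png"):
    """Build the Graphviz command line that renders dot_path into out_path."""
    return ["dot", "-T" + fmt, dot_path, "-o", out_path]

def render_dot(dot_path, out_path, fmt="png"):
    """Render a DOT file to an image; return False if `dot` is not installed."""
    if shutil.which("dot") is None:
        return False
    subprocess.run(dot_command(dot_path, out_path, fmt), check=True)
    return True

# Show the command that would be run for the file written by export_graphviz
cmd = dot_command("tree.dot", "tree.png")
print(" ".join(cmd))
```

Calling render_dot("tree.dot", "tree.png") after export_graphviz has written tree.dot should then produce the image, provided Graphviz is installed.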

Joshua Howard
  • I'm using Graphviz already, but I can't get it to show up as an image. It just writes it to the .dot file. I've tried changing it to a PDF but can't seem to get it to work. – Josh Dautel Dec 15 '16 at 16:23
  • I believe that this should just write the .dot file. You then must use one of the applications listed to view the .dot file. I personally like Gephi. – Joshua Howard Dec 15 '16 at 20:23
0

I made a visualization using bar plots. The first plot shows the distribution of the classes, and its title gives the first split criterion. All data satisfying this criterion go to the subplot below and to the left; the rest go to the subplot below and to the right. Each title thus gives the split criterion for the next split.

The percentages are relative to the initial distribution. Therefore, by looking at the percentages, one can easily see how much of the initial data is left after a few splits.
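
As a rough illustration of "how much data is left", the share of training samples reaching each node can be read from the fitted tree. This sketch uses tree_.n_node_samples rather than the value array used in the plotting code, and the iris data is only a stand-in.

```python
from sklearn import tree
from sklearn.datasets import load_iris

# Stand-in data (assumption): any fitted DecisionTreeClassifier works the same way
iris = load_iris()
clf = tree.DecisionTreeClassifier(max_depth=2, random_state=1)
clf.fit(iris.data, iris.target)

t = clf.tree_
n_total = t.n_node_samples[0]              # node 0 is the root and holds all samples
share = 100.0 * t.n_node_samples / n_total  # percentage of the data at each node

for node in range(t.node_count):
    # children_left is -1 for leaves
    kind = "leaf" if t.children_left[node] == -1 else "split"
    print(f"node {node} ({kind}): {share[node]:.1f}% of the training data")
```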

Note that a high max_depth entails a lot of subplots: the grid has shape (max_depth, 2^max_depth).
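
To get a feeling for how quickly that grid grows, here is a tiny calculation; grid_shape is a hypothetical helper that mirrors the (max_depth, 2^max_depth) layout described above.

```python
def grid_shape(max_depth):
    """Rows and columns of the subplot grid for a tree of the given max_depth."""
    return (max_depth, 2 ** max_depth)

for depth in (2, 4, 10):
    rows, cols = grid_shape(depth)
    print(f"max_depth={depth}: {rows} x {cols} grid, {rows * cols} cells")
```

A depth of 10, as in the question, would already need 1024 columns, so a small max_depth is advisable for this kind of plot.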

[Figure: Tree visualization using bar plots]

Code:

def give_nodes(nodes,amount_of_branches,left,right):
    amount_of_branches*=2
    nodes_splits=[]
    for node in nodes:
        nodes_splits.append(left[node])
        nodes_splits.append(right[node])
    return (nodes_splits,amount_of_branches)

def plot_tree(tree, feature_names):
    from matplotlib import gridspec 
    import matplotlib.pyplot as plt
    from matplotlib import rc
    import pylab

    color = plt.cm.coolwarm(np.linspace(1,0,len(feature_names)))

    plt.rc('text', usetex=True)
    plt.rc('font', family='sans-serif')
    plt.rc('font', size=14)

    params = {'legend.fontsize': 20,
             'axes.labelsize': 20,
             'axes.titlesize':25,
             'xtick.labelsize':20,
             'ytick.labelsize':20}
    plt.rcParams.update(params)

    max_depth=tree.max_depth
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    features  = [feature_names[i] for i in tree.tree_.feature]
    value = tree.tree_.value

    fig = plt.figure(figsize=(3*2**max_depth,2*2**max_depth))
    gs = gridspec.GridSpec(max_depth, 2**max_depth)
    plt.subplots_adjust(hspace = 0.6, wspace=0.8)

    # All data
    amount_of_branches=1
    nodes=[0]
    normalize=np.sum(value[0][0])

    for i,node in enumerate(nodes):
        # Integer division: GridSpec slice bounds must be ints in Python 3
        ax=fig.add_subplot(gs[0,(2**max_depth*i)//amount_of_branches:(2**max_depth*(i+1))//amount_of_branches])
        ax.set_title( features[node]+"$<= "+str(threshold[node])+"$")
        if( i==0): ax.set_ylabel(r'$\%$')
        ind=np.arange(1,len(value[node][0])+1,1)
        width=0.2
        bars= (np.array(value[node][0])/normalize)*100
        plt.bar(ind-width/2, bars, width,color=color,alpha=1,linewidth=0)
        plt.xticks(ind, [int(i) for i in ind-1])
        pylab.ticklabel_format(axis='y',style='sci',scilimits=(0,2))

    # Splits
    for j in range(1,max_depth):
        nodes,amount_of_branches=give_nodes(nodes,amount_of_branches,left,right)
        for i,node in enumerate(nodes):
            # Integer division: GridSpec slice bounds must be ints in Python 3
            ax=fig.add_subplot(gs[j,(2**max_depth*i)//amount_of_branches:(2**max_depth*(i+1))//amount_of_branches])
            ax.set_title( features[node]+"$<= "+str(threshold[node])+"$")
            if( i==0): ax.set_ylabel(r'$\%$')
            ind=np.arange(1,len(value[node][0])+1,1)
            width=0.2
            bars= (np.array(value[node][0])/normalize)*100
            plt.bar(ind-width/2, bars, width,color=color,alpha=1,linewidth=0)
            plt.xticks(ind, [int(i) for i in ind-1])
            pylab.ticklabel_format(axis='y',style='sci',scilimits=(0,2))


    plt.tight_layout()
    return fig

Example:

import numpy as np
from sklearn import tree

X=[]
Y=[]
amount_of_labels=5
feature_names=[ '$x_1$','$x_2$','$x_3$','$x_4$','$x_5$']
for i in range(200):
    X.append([np.random.normal(),np.random.randint(0,100),np.random.uniform(200,500) ])
    Y.append(np.random.randint(0,amount_of_labels))

clf = tree.DecisionTreeClassifier(criterion='entropy',max_depth=4)
clf = clf.fit(X,Y )
fig=plot_tree(clf, feature_names)
M. Bon