
I am working on a decision tree within a scikit-learn Pipeline. The model works fine, but I am not able to plot the decision tree: I am not sure which object to pass to the plotting function.

Here is my code to create the Decision Tree Model:

from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.preprocessing import (
    OneHotEncoder, PowerTransformer, StandardScaler
)

# Build categorical preprocessor
categorical_cols = X.select_dtypes(include="object").columns.to_list()
categorical_pipe = make_pipeline(
    OneHotEncoder(sparse=False, handle_unknown="ignore")
)

# Build numeric processor
to_log = ["SA13_peopleHH"]
to_scale = ["SA11_age"]
numeric_pipe_1 = make_pipeline(PowerTransformer())
numeric_pipe_2 = make_pipeline(StandardScaler())

# Full processor
full = ColumnTransformer(
    transformers=[
        ("categorical", categorical_pipe, categorical_cols),
        ("power_transform", numeric_pipe_1, to_log),
        ("standardization", numeric_pipe_2, to_scale),
    ]
)

# Final pipeline combined with DecisionTree
pipeline = Pipeline(
    steps=[
        ("preprocess", full),
        (
            "base",
            DecisionTreeClassifier(),
        ),
    ]
)
# Fit
_ = pipeline.fit(X_train, y_train)

This is how I try to call the plotting function:

tree.plot_tree(pipeline)
desertnaut

1 Answer


Based on this question: Getting model attributes from pipeline

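I think tree.plot_tree(pipeline['base']) will work. plot_tree expects the fitted tree estimator itself, not the whole Pipeline, and indexing the pipeline by step name returns that step.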

Dmitry