
I am experimenting with a tidymodels workflow for ctree using the new bonsai package, an extension for modeling with partykit. Here is my code:

pacman::p_load(tidymodels, bonsai, modeldata, finetune)

data(penguins)

doParallel::registerDoParallel()


split <- initial_split(penguins, strata = species)
df_train <- training(split)
df_test <- testing(split)

folds <-
  # vfold_cv(df_train, strata = species)
  bootstraps(df_train, strata = species, times = 5) # bootstraps, since the number of records is small


tree_recipe <-
  recipe(formula = species ~ flipper_length_mm + island, data = df_train) 

tree_spec <-
  decision_tree(
    tree_depth = tune(),
    min_n = tune()
  ) %>%
  set_engine("partykit") %>%
  set_mode("classification") 

tree_workflow <- 
  workflow() %>% 
  add_recipe(tree_recipe) %>% 
  add_model(tree_spec) 

set.seed(8833)
tree_tune <-
  tune_sim_anneal(
    tree_workflow, 
    resamples = folds, 
    iter = 30,
    initial = 4,
    metrics = metric_set(roc_auc, pr_auc, accuracy))


final_workflow <- finalize_workflow(tree_workflow, select_best(tree_tune, metric = "roc_auc"))

final_fit <- last_fit(final_workflow, split = split)

I understand that to extract the final fitted model I need to:

final_model <-  extract_fit_parsnip(final_fit)

And then I can plot the tree.

plot(final_model$fit)

I would like to try a different plotting library that works with partykit:

library(ggparty)

ggparty(final_model$fit) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(
    gglist = list(geom_bar(x = "", color = species),
                  xlab("species")),
    # draw individual legend for each plot
    shared_legend = FALSE
  )

The ggparty code works up to the last layer: without geom_node_plot() the tree renders fine, just without plots in the terminal nodes.

With that layer added, ggparty does not see the data inside the fitted model, namely the response variable species:

    Error in layer(data = data, mapping = mapping, stat = stat, geom = GeomBar,  : 
      object 'species' not found

How can I extract the final fit from tidymodels so that it contains the fitted values, as it would if I had built the model without a tidymodels workflow?

Jacek Kotowski

1 Answer


There are two problems in your code, and only one of them is related to tidymodels.

  1. The arguments to geom_bar() need to be wrapped in aes(), which is necessary both for plain ctree() output and for the result from the tidymodels workflow.

  2. The dependent variable in the output from the tidymodels workflow is not called species anymore but ..y (presumably a standardized placeholder employed in tidymodels). This can be seen from simply printing the object:

    final_model$fit
    ## Model formula:
    ## ..y ~ flipper_length_mm + island
    ## 
    ## Fitted party:
    ## [1] root
    ## ...
    

Addressing both of these (and using the fill aesthetic instead of color) makes the plot work as intended. (Bonus: autoplot(final_model$fit) also just works!)

ggparty(final_model$fit) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(gglist =  list(
    geom_bar(aes(x = "", fill = ..y)),
    xlab("species")
  ))

[Figure: ggparty visualization]
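For comparison, here is a minimal sketch of the same plot built from a plain partykit::ctree() fit, outside tidymodels. It assumes the modeldata, partykit, ggplot2, and ggparty packages are installed, and it drops rows with missing values purely for simplicity. Because ctree() receives the formula directly, the response keeps its original name, so species can be referenced by name; the aes() wrapper is still required.

```r
library(partykit)
library(ggplot2)
library(ggparty)

# Load modeldata's penguins and keep only the columns used in the model
data("penguins", package = "modeldata")
peng <- as.data.frame(na.omit(penguins[, c("species", "flipper_length_mm", "island")]))

# Plain ctree() fit: the response column stays named "species"
plain_tree <- ctree(species ~ flipper_length_mm + island, data = peng)
print(names(plain_tree$data))

# Same plot as in the answer, but the response is referenced by its original name
p_plain <- ggparty(plain_tree) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(gglist = list(
    geom_bar(aes(x = "", fill = species)),  # mappings must go inside aes()
    xlab("species")
  ))
p_plain
```

Running names(final_model$fit$data) on the object from the tidymodels workflow should instead list ..y in place of species.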

Achim Zeileis
  • Thank you, great answer, it helped me a lot. May I also ask what this would look like with geom_col(), i.e. aes(x = ..y, y = ???)? What would the variable names be for a geom_col() version to work? – Jacek Kotowski Sep 07 '22 at 07:53
  • You don't need an explicit `y`, the count is implied, isn't it? I think you just want `aes(x = ..y, fill = ..y)`. – Achim Zeileis Sep 07 '22 at 09:56
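A self-contained sketch of both the ..y renaming and the variant suggested in the last comment. It assumes that calling fit_xy() directly on the parsnip specification goes through the same x/y-to-formula conversion that the workflow uses (which is where ..y comes from), and it uses fixed, arbitrary hyperparameters instead of the tuned ones. geom_bar() computes the counts itself, so no y aesthetic is needed; the species simply move to the x axis.

```r
library(parsnip)
library(bonsai)
library(ggplot2)
library(ggparty)

# Load modeldata's penguins and keep only the columns used in the model
data("penguins", package = "modeldata")
peng <- as.data.frame(na.omit(penguins[, c("species", "flipper_length_mm", "island")]))

# Arbitrary hyperparameters standing in for the tuned values;
# fit_xy() passes predictors and outcome separately, as the workflow does
tree_fit <- decision_tree(tree_depth = 3, min_n = 20) %>%
  set_engine("partykit") %>%
  set_mode("classification") %>%
  fit_xy(x = peng[, c("flipper_length_mm", "island")], y = peng$species)

# The underlying partykit object, whose outcome column is named ..y
party_obj <- extract_fit_engine(tree_fit)
print(names(party_obj$data))

# Species on the x axis, counts computed by geom_bar()
p_y <- ggparty(party_obj) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(gglist = list(
    geom_bar(aes(x = ..y, fill = ..y)),
    xlab("species")
  ))
p_y
```

geom_col() would only be needed if the per-node counts had been precomputed into a column; with the raw observations in each node, geom_bar() does the counting.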