
This question refers to "Obtaining summary shap plot for catboost model with tidymodels in R". According to a comment below that question, the OP found a solution but has not shared it with the community so far.

I want to analyze tree ensembles fitted with tidymodels using SHAP value plots, such as force plots for single observations like

https://prnt.sc/CO_PC4aDUQA0

and summary plots of the effect of all features in my data set, like

[image: SHAP summary plot]

DALEXtra provides explain_tidymodels(), which creates an explainer for tidymodels workflows from which SHAP values can be computed. force_plot() from the fastshap package provides a wrapper for the plotting function of the underlying Python package shap. But I can't figure out how to make it work with the output of explain_tidymodels().

Question: How can one generate such SHAP plots in R using tidymodels and explain_tidymodels()?

MWE (for SHAP values with explain_tidymodels())

library(MASS)
library(tidyverse)
library(tidymodels)
library(parsnip)
library(treesnip)
library(catboost)
library(fastshap)
library(DALEXtra)
set.seed(1337)
rec <-  recipe(crim ~ ., data = Boston)

split <- initial_split(Boston)

train_data <- training(split)

test_data <- testing(split) %>% dplyr::select(-crim) %>% as.matrix()

model_default<-
  parsnip::boost_tree(
    mode = "regression"
  ) %>%
  set_engine(engine = 'catboost', loss_function = 'RMSE')
# Sometimes catboost is not loaded correctly; the following two lines
# prevent fitting errors (the error is mentioned in the last post of
# https://github.com/curso-r/treesnip/issues/21)
set_dependency("boost_tree", eng = "catboost", "catboost")
set_dependency("boost_tree", eng = "catboost", "treesnip")

model_fit_wf <- workflow() %>%
  add_model(model_default) %>%
  add_recipe(rec) %>%
  parsnip::fit(data = train_data)

# explain_tidymodels() takes the predictors (without the outcome) and the outcome
SHAP_wf <- explain_tidymodels(
  model_fit_wf,
  data = dplyr::select(train_data, -crim),
  y = train_data$crim
)
mugdi
    I haven't had much luck with catboost and treesnip myself, but you might find it helpful to look at [this blog post](https://juliasilge.com/blog/board-games/). Especially pay attention to how to use tidymodels output as input for functions like those from SHAPforxgboost, using `extract_fit_engine()` and `bake()`. – Julia Silge Mar 29 '22 at 22:47
  • I guess one of the main problems with catboost is that, to the best of my knowledge, there is still no R implementation of catboost from the original authors on CRAN, and I doubt they intend to release one. – mugdi Mar 30 '22 at 08:44
    @JuliaSilge Something important to consider for SHAP values is [this ongoing debate](https://github.com/slundberg/shap/issues/2345) about one of the main assumptions [from the original paper by Lundberg et al.](https://arxiv.org/abs/1802.03888) being violated by the Tree algorithm! If you work in a scientific field, you may need to qualify the validity of your results! – mugdi Mar 30 '22 at 09:14

1 Answer


Perhaps this will help. At the very least, it is a step in the right direction.

First, ensure you have fastshap and reticulate installed (e.g., install.packages("...")). Next, set up a virtual environment and install shap (pip install ...). Also, install matplotlib 3.2.2 for the dependence plots (see the related GitHub issues: an older version of matplotlib is necessary).

RStudio has good documentation on setting up virtual environments. That said, how much troubleshooting the setup needs depends on the IDE you use. (Sadly, some workplaces restrict the use of open-source RStudio due to licensing.)

The fastshap documentation is also helpful on this front.

Here's a workflow for lightgbm (from treesnip docs, lightly modified).

library(tidymodels)
library(treesnip)

data("diamonds", package = "ggplot2")
diamonds <- diamonds %>% sample_n(1000)

# vfold resamples
diamonds_splits <- vfold_cv(diamonds, v = 5)

model_spec <- boost_tree(mtry = 5, trees = 500) %>% set_mode("regression")

# model specs
lightgbm_model <- model_spec %>% 
    set_engine("lightgbm", nthread = 6)

#workflows
lightgbm_wf <- workflow() %>% 
    add_model(
       lightgbm_model
    )

rec_ordered <- recipe(price ~ ., data = diamonds)

lightgbm_fit_ordered <- fit_resamples(
  add_recipe(lightgbm_wf, rec_ordered),
  resamples = diamonds_splits
)

Before predicting, we fit the workflow:

fit_workflow <- lightgbm_wf %>% 
     add_recipe(rec_ordered) %>% 
     fit(data = diamonds)

Now we have a fitted workflow and can predict. To use fastshap::explain(), we need a prediction wrapper that returns a plain numeric vector of predictions (whether this is needed depends on the engine; some models work out of the box, see the docs).

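predict_function_gbm <- function(model, newdata) {
  # predict() on a workflow returns a tibble with a .pred column;
  # pluck(1) extracts it as the numeric vector fastshap expects
  predict(model, newdata) %>% pluck(1)
}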

While we're at it, let's compute the mean prediction, which is used as the baseline below. This also checks that the wrapper works.

mean_preds <- mean(
    predict_function_gbm(
       fit_workflow, diamonds %>% select(-price)
   )
)
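The mean prediction is a sensible baseline because of SHAP's additivity: the baseline plus an observation's SHAP values should add up to that observation's prediction (exactly for exact SHAP values, approximately for Monte Carlo estimates like fastshap's). For a linear model with independent features, the exact SHAP value of feature j is `b_j * (x_j - mean(x_j))`, so the property can be checked in a few lines of base R. A minimal sketch, assuming an `lm()` fit on `mtcars` as a stand-in for the LightGBM workflow; `linear_shap()` is a hypothetical helper:

```r
# Stand-in model: ordinary least squares on a few mtcars columns
fit <- lm(mpg ~ wt + hp + qsec, data = mtcars)
X <- mtcars[, c("wt", "hp", "qsec")]

# Hypothetical helper: exact SHAP values of a linear model
# (features treated as independent)
linear_shap <- function(fit, X) {
  b <- coef(fit)[colnames(X)]
  centered <- sweep(as.matrix(X), 2, colMeans(X))  # x_j - mean(x_j)
  sweep(centered, 2, b, `*`)                       # b_j * (x_j - mean(x_j))
}

phi <- linear_shap(fit, X)
baseline <- mean(predict(fit, X))  # plays the role of mean_preds above

# Additivity: baseline + row sum of SHAP values = prediction for that row
reconstructed <- baseline + rowSums(phi)
all.equal(unname(reconstructed), unname(predict(fit, X)))  # TRUE
```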

Now we create our explanations (SHAP values). Note the pred_wrapper and X arguments (see the fastshap GitHub issues for examples with other models, e.g., glmnet). nsim sets the number of Monte Carlo repetitions used to approximate the SHAP values.

fastshap::explain( 
    fit_workflow, 
    X = as.data.frame(diamonds %>% select(-price)),
    pred_wrapper = predict_function_gbm, 
    nsim = 10
) -> explanations_gbm
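To make the roles of `X`, `pred_wrapper` and `nsim` concrete, here is a from-scratch sketch of a Monte Carlo Shapley estimate for one feature: in each repetition, draw a random feature ordering and a random background row from `X`, then compare predictions with and without the observation's own value for that feature. This is an illustration of the general idea, not fastshap's actual implementation; `mc_shap_feature()` is a hypothetical helper, and an `lm()` fit on `mtcars` stands in for the workflow so the estimate can be compared with exact linear SHAP values.

```r
# Stand-in model and prediction wrapper (same contract as predict_function_gbm:
# take the model and a data frame, return a plain numeric vector)
fit <- lm(mpg ~ wt + hp + qsec, data = mtcars)
X <- mtcars[, c("wt", "hp", "qsec")]
pfun <- function(object, newdata) as.numeric(predict(object, newdata = newdata))

# Hypothetical helper: Monte Carlo Shapley estimate of one feature for every
# row of newdata, using the rows of X as background data
mc_shap_feature <- function(object, X, newdata, feature, pred_wrapper, nsim = 10) {
  p <- ncol(X)
  j <- match(feature, colnames(X))
  est <- numeric(nrow(newdata))
  for (s in seq_len(nsim)) {
    # one random background row per observation
    w <- X[sample.int(nrow(X), nrow(newdata), replace = TRUE), , drop = FALSE]
    # random feature ordering: features after j take background values
    ord <- sample.int(p)
    after_j <- ord[seq_len(p) > which(ord == j)]
    x_plus <- newdata
    if (length(after_j) > 0) x_plus[after_j] <- w[after_j]
    # x_minus additionally replaces feature j itself
    x_minus <- x_plus
    x_minus[[j]] <- w[[j]]
    est <- est + pred_wrapper(object, x_plus) - pred_wrapper(object, x_minus)
  }
  est / nsim
}

set.seed(1)
phi_wt <- mc_shap_feature(fit, X, X[1:5, ], "wt", pfun, nsim = 2000)
# exact linear SHAP values for comparison
exact_wt <- coef(fit)[["wt"]] * (X$wt[1:5] - mean(X$wt))
round(cbind(monte_carlo = phi_wt, exact = exact_wt), 2)
```

A larger `nsim` reduces Monte Carlo noise at the cost of more calls to the prediction wrapper, which is the trade-off behind choosing `nsim = 10` above.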

This should produce a force plot.

fastshap::force_plot(
    object = explanations_gbm[1,], 
    feature_values = as.data.frame(diamonds %>% select(-price))[1,], 
    display = "viewer", 
    baseline = mean_preds) 

This shows multiple observations, vertically stacked:

fastshap::force_plot(
    object = explanations_gbm[1:20,], 
    feature_values = as.data.frame(diamonds %>% select(-price))[1:20,], 
    display = "viewer", 
    baseline = mean_preds) 

For classification, add link = "logit". Change display to "html" for R Markdown rendering.
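As far as I understand it, `link = "logit"` only changes how tick labels are drawn: the SHAP values and baseline stay on the log-odds scale, and their sum is displayed as a probability. A base R sketch of that relationship, assuming a logistic regression on `mtcars` as a stand-in classifier (for which the exact log-odds SHAP value of feature j is `b_j * (x_j - mean(x_j))`):

```r
# Stand-in classifier: logistic regression on mtcars
fit <- glm(am ~ hp + wt, data = mtcars, family = binomial)
X <- mtcars[, c("hp", "wt")]

# Exact SHAP values on the log-odds (linear predictor) scale
b <- coef(fit)[colnames(X)]
centered <- sweep(as.matrix(X), 2, colMeans(X))
phi <- sweep(centered, 2, b, `*`)
baseline_logodds <- mean(predict(fit, X, type = "link"))

# Baseline + SHAP sum is a log-odds value; plogis() maps it to a probability
prob_from_shap <- plogis(baseline_logodds + rowSums(phi))
prob_direct <- predict(fit, X, type = "response")
all.equal(unname(prob_from_shap), unname(prob_direct))  # TRUE
```

Note that the matching baseline here is the mean log-odds, not the mean predicted probability.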

Now for summary plots and dependency plots.

The trick is to use reticulate to call the Python shap functions directly. The same logic holds for other Python libraries such as transformers, numpy, etc.

First, the dependence plot.

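library(reticulate)
shap <- import("shap")
np <- import("numpy")

shap$dependence_plot(
  "rank(3)",
  data.matrix(explanations_gbm),
  data.matrix(diamonds %>% select(-price)),
  # pass the column names; a plain R matrix reaches Python without them
  feature_names = names(diamonds %>% select(-price))
)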

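See the shap docs for an explanation of "rank(3)": it selects a feature by its position when features are ordered by mean absolute SHAP value. "rank(1)" etc. will also work.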
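To see which column a given "rank(k)" refers to, the ordering can be reproduced in R. Judging from the shap source, the count starts at 0, so "rank(0)" would be the most important feature; please treat that as an assumption. `shap_rank()` below is a hypothetical helper (not part of shap or fastshap), shown on a toy matrix with made-up values:

```r
# Hypothetical helper: which column does shap's "rank(k)" refer to?
# Assumes zero-based ranks ordered by mean |SHAP| (see above)
shap_rank <- function(shap_matrix, k) {
  importance <- colMeans(abs(shap_matrix))
  colnames(shap_matrix)[order(importance, decreasing = TRUE)[k + 1]]
}

# Toy SHAP matrix (made-up numbers) with three features
toy <- cbind(carat = c(2, -3, 1), cut = c(0.1, -0.2, 0.3), clarity = c(1, 1, -1))
shap_rank(toy, 0)  # "carat"   (mean |SHAP| = 2)
shap_rank(toy, 1)  # "clarity" (mean |SHAP| = 1)
shap_rank(toy, 2)  # "cut"     (mean |SHAP| = 0.2)
```

With the objects above, `shap_rank(data.matrix(explanations_gbm), 3)` should name the feature plotted by "rank(3)", under the same zero-based assumption.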

Unfortunately, it threw an error when I tried naming the feature directly (e.g., "cut").

Now for the summary plot:

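shap$summary_plot(
  data.matrix(explanations_gbm),
  data.matrix(diamonds %>% select(-price)),
  # without feature_names, the plot can only label features by position
  feature_names = names(diamonds %>% select(-price))
)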
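If the Python route proves fragile, the data behind a summary (beeswarm) plot can also be assembled in R and drawn with a plotting package such as ggplot2. A sketch, assuming `shap` is a numeric SHAP matrix and `X` the matching feature matrix with the same column order (for example `data.matrix(explanations_gbm)` and `data.matrix(diamonds %>% select(-price))`); `summary_plot_data()` is a hypothetical helper, demonstrated on toy matrices:

```r
# Hypothetical helper: long-format data for a SHAP summary (beeswarm) plot
summary_plot_data <- function(shap, X) {
  stopifnot(identical(dim(shap), dim(X)))
  # Rescale each feature to [0, 1] so all features share one colour scale
  rescale01 <- function(v) {
    rng <- range(v, na.rm = TRUE)
    if (diff(rng) == 0) return(rep(0.5, length(v)))
    (v - rng[1]) / diff(rng)
  }
  scaled <- apply(X, 2, rescale01)
  # Order features by mean |SHAP|, most important first
  lev <- colnames(shap)[order(colMeans(abs(shap)), decreasing = TRUE)]
  data.frame(
    # as.vector() flattens column by column, matching rep(..., each = nrow)
    feature = factor(rep(colnames(shap), each = nrow(shap)), levels = lev),
    shap = as.vector(shap),
    value = as.vector(scaled)
  )
}

# Toy example (made-up numbers)
shap_toy <- cbind(a = c(1, -2, 0.5), b = c(0.1, 0.2, -0.1))
X_toy <- cbind(a = c(10, 20, 30), b = c(5, 5, 6))
d <- summary_plot_data(shap_toy, X_toy)
d
```

The result could then be drawn with, e.g., `ggplot(d, aes(x = shap, y = feature, colour = value)) + geom_jitter(height = 0.2)`. ggplot2 puts the first factor level at the bottom, so reverse the levels if the most important feature should appear on top, as in shap's own summary plot.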

Final note: rendering the plots repeatedly can produce buggy visualizations. Hopefully this provides a point of departure for catboost visualizations.

  • Great work, thanks for the answer! Just a short addition for anyone who has problems with saved and loaded workflows containing fitted lightgbm models. For some reason, saving such a workflow with `write_rds()` does not save the actual lightgbm model. You have to extract and save the model from the workflow separately and recombine the two after loading to continue working with them. To save such a model separately, you could do something like: (1/2) – mugdi Apr 21 '22 at 16:16
  • `pull_lightgbm <- extract_fit_parsnip(final_model_cv); lightgbm::lgb.save(pull_lightgbm$fit, filename = str_c(here::here(), '/data/shap/lightgbm/', 'lightgbm.model_', dataset_type, '_', y))` (2/2) – mugdi Apr 21 '22 at 16:16