I'm unable to use the iml package in R to compute Shapley values for glmnet models.

It seems like the problem might be related to the fact that glmnet() and predict.glmnet() expect matrices, while the x.interest argument in iml::Shapley$new() expects a data frame, and so something is being incorrectly converted, but I'm not sure.

The most reasonable thing I've tried is below. Because of the following note in the iml::Predictor() documentation, I make sure my prediction function returns estimated probabilities for both classes: "Note: In case of classification, the model should return one column per class with the class probability."

library(dplyr)
library(iml)
library(glmnet)
df <- filter(iris, Species != 'setosa')
X <- as.matrix(select(df, -Species))
y <- droplevels(df$Species)
fit <- glmnet(X, y, family = 'binomial', lambda = 0.03)

predfun <- function(model, newdata) {
  preds <- predict(model, as.matrix(newdata), type = 'response') # probabilities
  return(cbind(1 - preds, preds)) # for both classes 
}

# Pass data frames, as requested
mod <- Predictor$new(fit, as.data.frame(X), predict.function = predfun) 
shapley <- Shapley$new(mod, x.interest = as.data.frame(X[1, ]))

This gives me the following: Error in predict.glmnet(model, as.matrix(newdata), type = "response"): The number of variables in newx must be 4

I'm not really sure what is being passed to predict.glmnet() that doesn't have four variables (it doesn't seem to have to do with an intercept from things I've tried). I've looked at the source code for Shapley$new() and also stepped for quite a while through a call via browser() but wasn't able to come up with anything useful.

Any ideas? Thank you!

1 Answer

Not 100% sure how to solve this API nightmare...

You can try exact KernelSHAP:

library(dplyr)
library(glmnet)

df <- filter(iris, Species != 'setosa')
X <- as.matrix(select(df, -Species))
y <- droplevels(df$Species)
fit <- glmnet(X, y, family = 'binomial', lambda = 0.03)

library(kernelshap)
library(shapviz)
library(ggplot2)
library(patchwork)

s <- shapviz(kernelshap(fit, X, bg_X = X))
sv_importance(s, kind = "bee", show_numbers = TRUE)
sv_dependence(s, colnames(X), color_var = NULL) &
  ylim(-4, 4)
sv_waterfall(s, row_id = 1)
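
As for the error in the question, one likely cause (I have not traced it through iml itself) is the x.interest argument: X[1, ] drops the matrix to a plain named vector, and as.data.frame() then turns that vector into a data frame with four rows and a single column rather than one row and four columns. The sketch below only checks the shapes, using base R and the same iris subset; the names bad and good are just for illustration.

```r
# Same two-class subset of iris as above, built with base R only
df <- iris[iris$Species != "setosa", ]
X <- as.matrix(df[, names(df) != "Species"])

# X[1, ] is a named numeric vector, so this becomes 4 rows x 1 column
bad <- as.data.frame(X[1, ])
dim(bad)
# 4 1

# drop = FALSE keeps the 1 x 4 matrix shape, so the data frame has 4 columns
good <- as.data.frame(X[1, , drop = FALSE])
dim(good)
# 1 4
```

If that is indeed the problem, passing x.interest = as.data.frame(X[1, , drop = FALSE]) to Shapley$new() should hand predict.glmnet() the four variables it expects, though I have not verified this against iml.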

Michael M