
In one of my robustness tests, I want to cross-validate a partial dependence plot (PDP), but I don't know where to start. My model is a regression tree, and I already have partial dependence plots based on the whole dataset. My questions are:

  1. If I randomly divide the dataset into 10 samples and calculate the partial dependence of variable X on Y for each sample, how can I average the results of the 10 samples into one plot? I cannot find a built-in function in Python or R that does this.

  2. The same task as above, but for a partial dependence plot of a 2-way interaction, for example, variables X1 and X2 on Y.

Thank you.

lindo
  • It's easier to help you if you include a simple [reproducible example](https://stackoverflow.com/questions/5963269/how-to-make-a-great-r-reproducible-example) with sample input that can be used to test and verify possible solutions. It's unclear how your data and models are currently set up. – MrFlick Jan 13 '21 at 20:16
  • Could you just look at individual conditional expectation (ICE) plots? The `ICEbox` package in R does these. – DaveArmstrong Jan 13 '21 at 20:21
  • As the PDP curve is the average of the ICE curves, averaging averages makes little sense: it would produce a plot very similar to the original PDP (unless the ten random sample groups overlap greatly and the averages are heavily biased towards some samples). As @DaveArmstrong suggests, use ICE plots or the ten PDPs of each group to see whether the curves differ from each other (you could define some distance that measures how alike the PDP curves are and use it to compare the robustness of the PDP curves for different models) – JaiPizGon Jan 13 '21 at 21:17
  • Thinking more about @JaiPizGon's useful answer, I wonder what it even means for your PDP to be "robust". It sounds like what you might really be interested in is the variance of the ICE curves around the PDP. One thing you could do is calculate the variance of the ICE curve predictions at each evaluation point on the x-axis. Or you could get some sense of variability around the PDP by bootstrapping the ICE curves to make bootstrapped PDPs. – DaveArmstrong Jan 13 '21 at 21:27
  • Thank you for the helpful comments! I will look at ICE curves – lindo Jan 14 '21 at 03:37

1 Answer


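Further to my comments above, if you want to look at the variability of the PDP, you can bootstrap the ICE curves: resample the curves with replacement, average each resample into a PDP, and take pointwise quantiles of the bootstrapped PDPs: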

library(pdp)
library(randomForest)
library(ICEbox)
data(boston)
X <- as.data.frame(model.matrix(cmedv ~ ., data=boston)[,-1])
y <- model.response(model.frame(cmedv ~ ., data=boston))
boston.rf <- randomForest(x=X, y=y)
bice <- ice(boston.rf, X=X, predictor = "lstat") 

res <- NULL
for(i in 1:1000){
  inds <- sample(1:nrow(bice$ice_curves), 
                 nrow(bice$ice_curves), 
                 replace=TRUE)
  res <- rbind(res, colMeans(bice$ice_curves[inds, ]))
}

out <- data.frame(
  fit = colMeans(bice$ice_curves), 
  lwr = apply(res, 2, quantile, .025),
  upr = apply(res, 2, quantile, .975), 
  x=bice$gridpts
)

library(ggplot2)
ggplot(out, aes(x=x, y=fit, ymin=lwr, ymax=upr)) + 
  geom_ribbon(alpha=.25) + 
  geom_line() + 
  theme_bw() + 
  labs(x="lstat", y="Prediction")

[Figure: PDP for lstat with a 95% bootstrap interval ribbon]
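If you would rather stay closer to the ten-random-samples idea from the question, a rough sketch could look like the block below. Everything in it is an assumption for illustration: the data are simulated, an `rpart` tree stands in for your regression tree, `pd_curve()` is a hypothetical helper (not a function from any package), and I chose to refit the tree with one fold left out each time rather than reuse a single model. The key point is that every fold is evaluated on the same grid, so the curves can be averaged pointwise.

```r
# Sketch only: simulated data and an rpart tree stand in for the asker's setup.
library(rpart)

set.seed(1)
n <- 500
dat <- data.frame(x1 = runif(n, 0, 10), x2 = runif(n, 0, 10))
dat$y <- sin(dat$x1) + 0.3 * dat$x2 + rnorm(n, sd = 0.3)

# Hypothetical helper: partial dependence of `var` on a fixed grid.
# For each grid value, set `var` to that value in every row, predict, average.
pd_curve <- function(model, data, var, grid) {
  sapply(grid, function(v) {
    data[[var]] <- v
    mean(predict(model, newdata = data))
  })
}

k <- 10
fold <- sample(rep(seq_len(k), length.out = n))          # random fold labels
grid <- seq(min(dat$x1), max(dat$x1), length.out = 25)   # grid shared by all folds

# One PDP per fold: refit with fold i left out, compute the PDP on that training set.
pd_mat <- sapply(seq_len(k), function(i) {
  train <- dat[fold != i, ]
  fit <- rpart(y ~ x1 + x2, data = train)
  pd_curve(fit, train, "x1", grid)
})
# pd_mat has one row per grid point and one column per fold.

pd_summary <- data.frame(
  x   = grid,
  fit = rowMeans(pd_mat),        # pointwise average of the k curves
  lwr = apply(pd_mat, 1, min),   # pointwise envelope across folds
  upr = apply(pd_mat, 1, max)
)
```

`pd_summary` has the same `x`, `fit`, `lwr`, `upr` columns as `out` above, so the same `ggplot()` ribbon call should work on it. If the ten curves sit close together, that is some evidence the PDP is stable; plotting the columns of `pd_mat` individually is usually more informative than the average alone.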

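Or, you could look at different quantiles of the ICE curves at each evaluation point. Note that this shows the spread of the individual curves around the PDP rather than the sampling variability of the PDP itself.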

tmp <- t(apply(bice$ice_curves, 
             2, 
             quantile, c(0, .025, .05, .1, .25, .5, .75, .9, .95, .975, 1)))

head(tmp)
tmp <- as.data.frame(tmp)
names(tmp) <- c("l1", "l2", "l3", "l4", "l5", 
                "med", "u1", "u2", "u3", "u4", "u5")

tmp$x <- bice$gridpts

ggplot(tmp, aes(x=x, y=med)) + 
  geom_ribbon(aes(ymin=l1, ymax=u1), alpha=.2) + 
  geom_ribbon(aes(ymin=l2, ymax=u2), alpha=.2) + 
  geom_ribbon(aes(ymin=l3, ymax=u3), alpha=.2) + 
  geom_ribbon(aes(ymin=l4, ymax=u4), alpha=.2) + 
  geom_ribbon(aes(ymin=l5, ymax=u5), alpha=.2) + 
  geom_line() + 
  theme_bw() + 
  labs(x="lstat", y="Prediction")

[Figure: median ICE curve for lstat with nested quantile ribbons]
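For the 2-way case in the second question, the same logic should carry over if the grid is a set of (X1, X2) pairs instead of a single sequence. The block below is a minimal sketch under the same assumptions as before: simulated data, an `rpart` tree as a stand-in model, and a hypothetical helper `pd_surface()` that is not part of any package. A third predictor `x3` is included so that averaging over the remaining variables actually does something.

```r
# Sketch only: simulated data and an rpart tree stand in for the asker's setup.
library(rpart)

set.seed(2)
n <- 500
dat <- data.frame(x1 = runif(n), x2 = runif(n), x3 = runif(n))
dat$y <- 2 * dat$x1 * dat$x2 + dat$x3 + rnorm(n, sd = 0.1)

# Shared 10 x 10 grid of (x1, x2) pairs, one row per pair.
grid <- expand.grid(
  x1 = seq(0, 1, length.out = 10),
  x2 = seq(0, 1, length.out = 10)
)

# Hypothetical helper: 2-way partial dependence on the grid.
# For each pair, fix x1 and x2 in every row, predict, and average over x3.
pd_surface <- function(model, data, grid) {
  vapply(seq_len(nrow(grid)), function(j) {
    data$x1 <- grid$x1[j]
    data$x2 <- grid$x2[j]
    mean(predict(model, newdata = data))
  }, numeric(1))
}

k <- 10
fold <- sample(rep(seq_len(k), length.out = n))   # random fold labels

# One surface per fold: refit with fold i left out, evaluate on the shared grid.
pd_mat <- sapply(seq_len(k), function(i) {
  train <- dat[fold != i, ]
  fit <- rpart(y ~ x1 + x2 + x3, data = train)
  pd_surface(fit, train, grid)
})

grid$fit <- rowMeans(pd_mat)          # average surface across folds
grid$sd  <- apply(pd_mat, 1, sd)      # fold-to-fold spread at each grid cell
```

The averaged surface can then be drawn as a heat map, for example with `ggplot(grid, aes(x1, x2, fill = fit)) + geom_tile()`, and mapping `fill = sd` instead shows where the folds disagree most.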

DaveArmstrong