
I am using the gbm function in R (gbm package) to fit stochastic gradient boosting models for multiclass classification. I am simply trying to obtain the importance of each predictor separately for each class, as in this figure from Hastie et al., The Elements of Statistical Learning (p. 382).

[Figure: relative importance of each predictor, shown separately for each class (The Elements of Statistical Learning, p. 382)]

However, the function summary.gbm only returns the overall importance of the predictors (their importance averaged over all classes).

Does anyone know how to get the class-specific relative importance values?

Antoine

3 Answers


I think the short answer is that on page 379 Hastie mentions that he uses MART, which appears to be available only for S-PLUS.

I agree that the gbm package doesn't seem to allow for seeing the separate relative influence. If that's something you're interested in for a multiclass problem, you could probably get something pretty similar by building a one-vs-all gbm for each of your classes and then getting the importance measures from each of those models.

So say your classes are a, b, c, and d. You model a vs. the rest and get the importance from that model. Then you model b vs. the rest and get the importance from that model. Etc.
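
A minimal sketch of this one-vs-rest idea, assuming the built-in iris data (with Species as the class) and one bernoulli gbm per class. The helper name one_vs_rest_importance is hypothetical, and the tuning values (n.trees, interaction.depth, bag.fraction) are arbitrary placeholders, not recommendations:

```r
library(gbm)

# Hypothetical helper: fit one binary gbm per class (class k vs. all others)
# and return the relative influence of every predictor for each class.
one_vs_rest_importance <- function(data, response, n.trees = 100) {
  y <- factor(data[[response]])
  predictors <- setdiff(names(data), response)
  imp <- sapply(levels(y), function(k) {
    d <- data[predictors]
    d$is_k <- as.integer(y == k)  # 1 for class k, 0 for every other class
    fit <- gbm(is_k ~ ., data = d, distribution = "bernoulli",
               n.trees = n.trees, interaction.depth = 2,
               bag.fraction = 0.5)
    s <- summary(fit, n.trees = n.trees, plotit = FALSE)
    # summary() sorts by importance, so re-align rows to a fixed predictor order
    s$rel.inf[match(predictors, as.character(s$var))]
  })
  rownames(imp) <- predictors
  imp  # rows = predictors, columns = classes
}

set.seed(1)
imp <- one_vs_rest_importance(iris, "Species")
round(imp, 1)
```

Since summary() normalizes by default, each column sums to 100. Each column comes from a separate binary model, so this approximates, rather than reproduces, the per-class importance of a single multiclass model.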

Tchotchke
  • Almost 3 years have passed but so far there is no answer. Do you have any additional hint besides the effective work-around you proposed in this answer? – Seymour Apr 09 '18 at 21:38
  • Unfortunately not - I haven't looked into it much over the years and I've primarily been working in Python – Tchotchke Apr 10 '18 at 14:07
  • Does Python offer this kind of solution? – Seymour Apr 10 '18 at 14:19
  • @Tchotchke What do you think of my method of using the error reduction for each tree (see answer below)? I am using this in my work and I would really appreciate any thoughts you might have. – see24 Nov 01 '18 at 17:57

Hopefully this function helps. For the example I used the marketing data from the ElemStatLearn package. The function finds the distinct values (classes) of a column, splits the data by them, runs gbm() on each subset, and draws a bar plot of the relative influence for each model.

# install.packages("ElemStatLearn"); install.packages("gbm")
library(ElemStatLearn)  # provides the `marketing` data set
library(gbm)

set.seed(137531)

# formula: the formula to pass to gbm()
# data:    the data set to use
# column:  the column (index or name) whose values define the subsets
classPlots <- function(formula, data, column) {

    # split the rows of `data` (not a hard-coded data set) by the values
    # of `column`; rows where `column` is NA are dropped
    split_data   <- split(data, as.character(data[[column]]))
    class_values <- names(split_data)

    # fit one gbm() per subset and extract its relative influence table
    object  <- lapply(split_data, function(x) gbm(formula, data = x))
    rel.inf <- lapply(object, function(x) summary(x, plotit = FALSE))

    for (i in seq_along(class_values)) {
        tmp <- rel.inf[[i]]
        # named vector so barplot() labels the bars with the variable names
        imp <- setNames(tmp$rel.inf, as.character(tmp$var))
        barplot(imp, horiz = TRUE, col = "red",
                xlab = "Relative importance",
                main = paste0("Class = ", class_values[i]))
    }
    rel.inf
}

par(mfrow = c(1, 2))
classPlots(Income ~ Marital + Age, data = marketing, column = 2)


[Output: one horizontal bar plot of relative importance per class]

nathanesau
  • The interpretation for this example would be that age greatly affects male income and marital status greatly affects female income – nathanesau Aug 14 '15 at 18:10
  • Thank you very much for this helpful answer. Let me play with your commands in detail before I accept the answer/award the bounty. Also, from a theoretical standpoint, I am wondering whether it is valid to compare the influence that variables have in two separate models... – Antoine Aug 15 '15 at 18:19
  • In fact, it is the same model, just on two subsets of the data. Why would this be invalid? – nathanesau Aug 15 '15 at 18:42
  • We are using the same algorithm in both cases, granted, but in the end we obtain two different models, since the data sets are different. If you compare the final equations (boosting is similar to a generalized additive model), they won't be the same. So it's not as if we were comparing the relative importance of variables in predicting each class within a single model. – Antoine Aug 15 '15 at 20:29
  • Agreed - when I proposed this solution above it was an approximation of the solution you were looking for. I don't think it's quite doing the same thing as Hastie did, but it probably gets close enough (and is the easiest thing to do out-of-the-box in R that I could think of) – Tchotchke Aug 16 '15 at 16:26

I did some digging into how the gbm package calculates importance. It is based on the ErrorReduction of each split, which is stored in the trees element of the fitted object and can be accessed with pretty.gbm.tree(). The relative influence of a variable is the sum of this ErrorReduction over all splits on that variable in all trees. For a multiclass problem the model actually contains n.trees*num.classes trees (one tree per class at each boosting iteration). So with 3 classes you can sum the ErrorReduction for each variable over every third tree to get the importance for one class. I have written the following functions to implement this and then plot the results:

Get Variable Importance By Class

RelInf_ByClass <- function(object, n.trees, n.classes, Scale = TRUE){
  library(dplyr)
  library(purrr)
  library(gbm)
  Ext_ErrRed<- function(ptree){
    ErrRed <- ptree %>% filter(SplitVar != -1) %>% group_by(SplitVar) %>% 
      summarise(Sum_ErrRed = sum(ErrorReduction))
  }
  trees_ErrRed <- map(1:n.trees, ~pretty.gbm.tree(object, .)) %>% 
    map(Ext_ErrRed)

  trees_by_class <- split(trees_ErrRed, rep(1:n.classes, n.trees/n.classes)) %>% 
    map(~bind_rows(.) %>% group_by(SplitVar) %>% 
          summarise(rel_inf = sum(Sum_ErrRed)))
  varnames <- data.frame(Num = 0:(length(object$var.names)-1),
                         Name = object$var.names)
  classnames <- data.frame(Num = 1:object$num.classes, 
                           Name = object$classes)
  out <- trees_by_class %>% bind_rows(.id = "Class") %>%  
    mutate(Class = classnames$Name[match(Class,classnames$Num)],
    SplitVar = varnames$Name[match(SplitVar,varnames$Num)]) %>%
    group_by(Class) 
  if (Scale) {
    # scale so the most influential variable in each class gets 100
    out <- out %>% mutate(Scaled_inf = rel_inf / max(rel_inf) * 100)
  }
  out
}
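
The same per-class sum can be sketched in base R without dplyr or purrr. Like RelInf_ByClass, this assumes that pretty.gbm.tree() reports SplitVar as a 0-based index (-1 for terminal nodes) alongside an ErrorReduction column, that object$num.classes and object$classes are set for a multinomial fit, and that the trees are stored iteration by iteration with one tree per class, so tree t belongs to class ((t - 1) %% K) + 1. The function name rel_inf_by_class_base is hypothetical:

```r
library(gbm)

# Hypothetical base-R variant of RelInf_ByClass(): returns a matrix with one
# row per predictor and one column per class, holding the summed
# ErrorReduction of every split on that predictor in that class's trees.
rel_inf_by_class_base <- function(object) {
  K <- object$num.classes
  out <- matrix(0, nrow = length(object$var.names), ncol = K,
                dimnames = list(object$var.names, object$classes))
  # for a multinomial fit, object$trees holds n.trees * K trees
  for (t in seq_along(object$trees)) {
    tr <- pretty.gbm.tree(object, i.tree = t)
    tr <- tr[tr$SplitVar != -1, , drop = FALSE]  # keep split (non-terminal) nodes
    k <- ((t - 1) %% K) + 1                       # assumed class of tree t
    for (r in seq_len(nrow(tr))) {
      v <- tr$SplitVar[r] + 1                     # SplitVar is 0-based
      out[v, k] <- out[v, k] + tr$ErrorReduction[r]
    }
  }
  out
}

set.seed(42)
gbm_iris <- gbm(Species ~ ., data = iris, distribution = "multinomial",
                n.trees = 100)
imp <- rel_inf_by_class_base(gbm_iris)
round(imp, 1)   # per-class, unscaled influence
rowSums(imp)    # overall influence, summed over all classes
```

rowSums(imp) adds up every split in every tree, so it does not depend on the ordering assumption; only the split into columns does.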

Plot Variable Importance By Class

In my real use case I have over 40 features, so the plotting function takes the number of features to plot. I also couldn't use faceting if I wanted each class's plot to be sorted separately, which is why I used gridExtra.

plot_imp_byclass <- function(df, n) {
  library(ggplot2)
  library(gridExtra)
  plot_imp_class <- function(df){
    df %>% arrange(rel_inf) %>% 
      mutate(SplitVar = factor(SplitVar, levels = .$SplitVar)) %>% 
      ggplot(aes(SplitVar, rel_inf))+
      geom_segment(aes(x = SplitVar, 
                       xend = SplitVar, 
                       y = 0, 
                       yend = rel_inf))+
      geom_point(size=3, col = "cyan") + 
      coord_flip()+
      labs(title = df$Class[[1]], x = "Variable", y = "Importance")+
      theme_classic()+
      theme(plot.title = element_text(hjust = 0.5))
  }

  df %>% top_n(n, rel_inf) %>% split(.$Class) %>% 
    map(plot_imp_class) %>% map(ggplotGrob) %>% 
    {grid.arrange(grobs = .)}
}

Try It

gbm_iris <- gbm(Species~., data = iris)
imp_byclass <- RelInf_ByClass(gbm_iris, length(gbm_iris$trees), 
                              gbm_iris$num.classes, Scale = F)
plot_imp_byclass(imp_byclass, 4)

This seems to give the same results as the built-in relative.influence() function when the per-class results are summed over all classes.

relative.influence(gbm_iris)
# n.trees not given. Using 100 trees.
# Sepal.Length  Sepal.Width Petal.Length  Petal.Width 
#      0.00000     51.88684   2226.88017    868.71085 

imp_byclass %>% group_by(SplitVar) %>% summarise(Overall_rel_inf = sum(rel_inf))
# A tibble: 3 x 2
#   SplitVar     Overall_rel_inf
#   <fct>                  <dbl>
# 1 Petal.Length          2227. 
# 2 Petal.Width            869. 
# 3 Sepal.Width             51.9
see24
  • Thanks, I'll take a close look in the weeks to come. In the meantime +1 for sharing your code! The `gbm` package has been orphaned, and it seems the most recent version lives on GitHub as `gbm3`: https://github.com/gbm-developers/gbm3 – Antoine Aug 22 '18 at 11:23
  • Great! I have opened a github [issue](https://github.com/gbm-developers/gbm3/issues/154#issue-352941598) on the gbm3 page so hopefully they will add this functionality to the new version – see24 Aug 22 '18 at 12:56
  • @Antoine have you had a chance to look at this yet? I've started using it in my work so if you see any issues I would be grateful to hear about them! – see24 Sep 17 '18 at 18:24
  • I'm just back from a two-week vacation. I won't have time to look at this in the very short term, but I'll try as soon as I have a chance – Antoine Sep 17 '18 at 20:57