
I just built a basic classification model with the caret package, using the "xgbTree" method (Extreme Gradient Boosting). It separates 3 classes with great accuracy, but I can't see the rules or plot the tree.

Does anyone know how to plot the tree when the model is built with caret? I tried the `xgb.plot.tree` function from the xgboost package, but I get an error saying it can't be plotted because my model is not an object of class `xgb.Booster` generated by the `xgb.train` function. Is there a way to coerce the model I built in caret into an `xgb.Booster` object?

I appreciate any help.

David Heckmann
HunkyGoon
  • Please provide a [reproducible example](http://stackoverflow.com/questions/5963269/how-to-make-a-great-r-reproducible-example). – David Heckmann Apr 28 '17 at 17:38
  • Model-building code:

        library(caret)
        library(tictoc)  # assumed: tic() comes from the tictoc package
        # Building (training) the model
        tic()
        myegb <- train(Confirmed.Diagnosis ~ ., method = "xgbTree", data = training, na.action = na.pass)
        names(myegb)
        print(myegb$finalModel)

    – HunkyGoon Apr 28 '17 at 17:44
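A reproducible stand-in for the code above, as a minimal sketch: it assumes the built-in `iris` data (which also has three classes) in place of the asker's `training` set, and a small fixed tuning grid whose values are illustrative, not tuned. It shows that `train()` already keeps the fitted xgboost booster in `$finalModel`, so no coercion should be needed, and prints the start of the text dump of the split rules. It is written against the caret/xgboost APIs of this thread's era; later xgboost releases changed parts of this interface.

```r
library(caret)
library(xgboost)

set.seed(42)

# Illustrative single-row grid (not tuned) covering the tuning
# parameters caret's "xgbTree" method expects.
grid <- expand.grid(nrounds = 5, max_depth = 2, eta = 0.3, gamma = 0,
                    colsample_bytree = 1, min_child_weight = 1,
                    subsample = 1)

# iris stands in for the asker's 'training' data: 3 classes in Species.
fit <- train(Species ~ ., data = iris, method = "xgbTree",
             trControl = trainControl(method = "none"),
             tuneGrid = grid)

# The booster that caret fitted with xgboost is stored in $finalModel.
booster <- fit$finalModel
print(class(booster))

# Text dump of the split rules, one block per tree.
rules_text <- xgb.dump(booster, with_stats = TRUE)
print(head(rules_text, 10))
```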

2 Answers


I had the same issue. According to the help page (`?xgb.plot.tree`), the first parameter is `feature_names`, a vector of feature names, so you must pass the model through the named `model` parameter:

xgb.plot.tree(model = myegb$finalModel)

This produces the tree diagram without feature names; the nodes are labeled with column indices instead.

To add feature names, pass the predictor names that caret stores with the fitted model in `$xNames`:

xgb.plot.tree(feature_names = myegb$finalModel$xNames, model = myegb$finalModel)
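A self-contained sketch of the same call, assuming the `iris` data as a stand-in and an illustrative fixed grid. It names the `model` argument and takes the feature names from `$xNames`, which assumes your caret version attaches the predictor column names there. Rendering the plot also needs the DiagrammeR package, so the call is guarded.

```r
library(caret)
library(xgboost)

set.seed(42)
grid <- expand.grid(nrounds = 5, max_depth = 2, eta = 0.3, gamma = 0,
                    colsample_bytree = 1, min_child_weight = 1,
                    subsample = 1)
fit <- train(Species ~ ., data = iris, method = "xgbTree",
             trControl = trainControl(method = "none"), tuneGrid = grid)

# Predictor names in the column order the booster was trained on
# (with the formula interface, after any dummy coding of factors).
feat <- fit$finalModel$xNames
print(feat)

# Name 'model' explicitly: the first positional argument of
# xgb.plot.tree() is feature_names, not the model.
if (requireNamespace("DiagrammeR", quietly = TRUE)) {
  xgb.plot.tree(feature_names = feat, model = fit$finalModel)
}
```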
Matt L.

You need to pass the `trees` argument:

xgb.plot.tree(model = myegb$finalModel, trees = tree_index)

`tree_index` is the index of the tree you want to plot; without it, all the trees are drawn in one figure and the details become unreadable. In `xgb.plot.tree`, tree indices start from 0, not 1.
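A sketch of picking individual trees, again assuming the `iris` stand-in and an illustrative grid, and the zero-based `trees` indexing of the xgboost versions discussed here. In a 3-class model xgboost grows one tree per class in every boosting round, so trees 0, 1 and 2 make up the first round. The same `trees` argument works with `xgb.model.dt.tree()`, which returns the split rules as a table rather than a plot.

```r
library(caret)
library(xgboost)

set.seed(42)
grid <- expand.grid(nrounds = 5, max_depth = 2, eta = 0.3, gamma = 0,
                    colsample_bytree = 1, min_child_weight = 1,
                    subsample = 1)
fit <- train(Species ~ ., data = iris, method = "xgbTree",
             trControl = trainControl(method = "none"), tuneGrid = grid)
booster <- fit$finalModel

# 5 rounds x 3 classes = 15 trees; zero-based tree i belongs to
# class level (i %% 3) + 1 of Species.
n_class <- nlevels(iris$Species)
first_round <- 0:(n_class - 1)   # zero-based indices of round 1

# Split rules of the first round as a data.table
# (columns such as Tree, Node, Feature, Split, Yes, No, Quality).
rules <- xgb.model.dt.tree(feature_names = booster$xNames,
                           model = booster, trees = first_round)
print(rules)

# Plot a single tree so its details stay readable.
if (requireNamespace("DiagrammeR", quietly = TRUE)) {
  xgb.plot.tree(feature_names = booster$xNames, model = booster, trees = 0)
}
```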