I have built a decision tree model in R using rpart and ctree. I have also scored a new dataset with the fitted model and obtained predicted probabilities and classes.

However, I would like to extract, as a single string, the rule/path that every observation in the predicted dataset has followed. With this stored in tabular format, I can explain each prediction with its reason in an automated manner, without opening R.

That is, I want to get the following:

ObsID   Probability   PredictedClass   PathFollowed 
    1          0.68             Safe   CarAge < 10 & Country = Germany & Type = Compact & Price < 12822.5
    2          0.76             Safe   CarAge < 10 & Country = Korea & Type = Compact & Price > 12822.5
    3          0.88           Unsafe   CarAge > 10 & Type = Van & Country = USA & Price > 15988

The kind of code I'm looking for is:

library(rpart)
fit <- rpart(Reliability~.,data=car.test.frame)

This is the part that probably needs to be expanded into multiple lines:

predResults <- predict(fit, newdata = newcar, type= "GETPATTERNS")
arindam adak

1 Answer

The partykit package has a function .list.rules.party() which is currently unexported (so it has to be called via partykit:::) but can be leveraged to do what you want. The main reason that we haven't exported it yet is that its type of output may change in future versions.
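As a minimal sketch of what that function returns on its own (assuming the behavior of current partykit versions, which may change as noted above), it can be called on a small tree to list one rule string per terminal node. The object names rp_party and rls are only illustrative.

```r
library("rpart")
library("partykit")

## fit an rpart tree and coerce it to a "party" object
rp_party <- as.party(rpart(Species ~ ., data = iris))

## one rule string per terminal node; this relies on an unexported
## function whose output format may change in future partykit versions
rls <- partykit:::.list.rules.party(rp_party)

## the vector is named by terminal node id (as character),
## which is what allows indexing it by the predicted node later on
print(rls)
print(nodeids(rp_party, terminal = TRUE))
```

Because the rules are keyed by node id, the predicted node of each observation is enough to look up its path.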

To obtain the predictions you describe above you can do:

pathpred <- function(object, ...)
{
  ## coerce to "party" object if necessary
  if(!inherits(object, "party")) object <- as.party(object)

  ## get standard predictions (response/prob) and collect in data frame
  rval <- data.frame(response = predict(object, type = "response", ...))
  rval$prob <- predict(object, type = "prob", ...)

  ## get rules for each node
  rls <- partykit:::.list.rules.party(object)

  ## get predicted node and select corresponding rule
  rval$rule <- rls[as.character(predict(object, type = "node", ...))]

  return(rval)
}

Illustration using the iris data and rpart():

library("rpart")
library("partykit")
rp <- rpart(Species ~ ., data = iris)
rp_pred <- pathpred(rp)
rp_pred[c(1, 51, 101), ]
##       response prob.setosa prob.versicolor prob.virginica
## 1       setosa  1.00000000      0.00000000     0.00000000
## 51  versicolor  0.00000000      0.90740741     0.09259259
## 101  virginica  0.00000000      0.02173913     0.97826087
##                                           rule
## 1                          Petal.Length < 2.45
## 51   Petal.Length >= 2.45 & Petal.Width < 1.75
## 101 Petal.Length >= 2.45 & Petal.Width >= 1.75

(Only the first observation of each species is shown for brevity here. This corresponds to indexes 1, 51, and 101.)
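A quick way to confirm where those indexes come from, using only base R (first_idx is an illustrative name):

```r
## position of the first row of each species in the iris data
first_idx <- match(levels(iris$Species), iris$Species)
print(first_idx)
```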

And with ctree():

ct <- ctree(Species ~ ., data = iris)
ct_pred <- pathpred(ct)
ct_pred[c(1, 51, 101), ]
##       response prob.setosa prob.versicolor prob.virginica
## 1       setosa  1.00000000      0.00000000     0.00000000
## 51  versicolor  0.00000000      0.97826087     0.02173913
## 101  virginica  0.00000000      0.02173913     0.97826087
##                                                              rule
## 1                                             Petal.Length <= 1.9
## 51  Petal.Length > 1.9 & Petal.Width <= 1.7 & Petal.Length <= 4.8
## 101                        Petal.Length > 1.9 & Petal.Width > 1.7
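The question asks about a new dataset and about storing the result in tabular form, so here is a hedged sketch of that case. It applies the same steps as pathpred() directly, with newdata passed to predict(). The train/new split, the seed, the output file name, and the choice to report the largest class probability are assumptions made for illustration; the column names follow the table in the question.

```r
library("rpart")
library("partykit")

## hypothetical split: treat 50 held-out iris rows as the "new" data
set.seed(1)
idx <- sample(nrow(iris), 100)
train <- iris[idx, ]
newobs <- iris[-idx, ]

## fit on the training rows and coerce to a "party" object
fit <- as.party(rpart(Species ~ ., data = train))

## rules per terminal node (unexported function, output may change)
rules <- partykit:::.list.rules.party(fit)

## predicted terminal node, class, and class probabilities for the new rows
node <- predict(fit, newdata = newobs, type = "node")
cls <- predict(fit, newdata = newobs, type = "response")
prob <- predict(fit, newdata = newobs, type = "prob")

## one row per observation, with the path as a single string
out <- data.frame(
  ObsID = rownames(newobs),
  Probability = apply(prob, 1, max),
  PredictedClass = cls,
  PathFollowed = unname(rules[as.character(node)]),
  stringsAsFactors = FALSE
)

## write to a CSV file so the result can be read without opening R
out_file <- file.path(tempdir(), "path_predictions.csv")
write.csv(out, out_file, row.names = FALSE)
```

With the pathpred() function from above, the same predictions should be obtainable via pathpred(fit, newdata = newobs), since its ... argument is forwarded to predict().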
Achim Zeileis
  • I'm trying to reproduce this example to make sure I understand everything. Can you explain the significance of using the indices 1, 51, and 101 in rp_pred[c(1, 51, 101), ] ? – Kyle. Oct 16 '15 at 17:51
  • I couldn't show the paths for all 150 observations in the `iris` dataset. So I simply picked the first observations of each species in the response. There is no deeper meaning associated with that. Will add a note to my answer. – Achim Zeileis Oct 16 '15 at 20:00
  • Ah, thanks! I thought it might be something like that, but I wanted to make sure I wasn't missing something. – Kyle. Oct 16 '15 at 20:53
  • @Kyle. Have you managed to implement this in Python? I am trying to do the exact same thing in Python, however I can't manage to find anything. I saw your question (https://datascience.stackexchange.com/questions/8440/extract-the-path-of-a-data-point-through-a-decision-tree-in-sklearn) and was hoping you might be the one person that knows how to do this in Python. – codiearcher May 27 '19 at 15:50
  • This function is awesome. Thank you. – igorkf Feb 07 '22 at 01:23