Similar questions have been asked, for example here and here, but none of them can be applied to my issue. I'm trying to determine and count which observations are in each node of a decision tree. However, the tree structure comes from a data frame of trees that I'm creating myself: I'm extracting tree information from the BART package and turning it into a data frame that resembles the one shown below (i.e., df), and I need to work with that data frame structure. Aside: I believe the way the trees are ordered in my data frame is called 'depth first'.
For example, my data frame of trees looks like this:
library(dplyr)
df <- tibble(variableName = c("x2", "x1", NA, NA, NA, "x2", NA, NA, "x5", "x4", NA, NA, "x3", NA, NA),
             splitValue = c(0.542, 0.126, NA, NA, NA, 0.6547, NA, NA, 0.418, 0.234, NA, NA, 0.747, NA, NA),
             treeNo = c(1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3))
Visually, these trees would look like:
The trees are drawn left-first when traversing down df. Additionally, all splits are binary, so each internal node has exactly 2 children, and a row with NA for variableName is a terminal node (leaf).
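To make the depth-first, left-first layout concrete, here is a minimal sketch that recovers each node's children from the row order alone. The helper name getChildren is hypothetical (it is not part of BART or dplyr), and it assumes that a row with NA in variableName is a leaf and that every internal node has exactly two children.

```r
# Hypothetical helper: for ONE tree stored in depth-first, left-first order,
# find the row index of the left and right child of every node.
# Assumption: NA in variableName marks a leaf.
getChildren <- function(variableName) {
  n <- length(variableName)
  left <- rep(NA_integer_, n)
  right <- rep(NA_integer_, n)

  # Visit the subtree rooted at row i and return the index of the
  # last row that belongs to that subtree.
  walk <- function(i) {
    if (is.na(variableName[i])) return(i)  # leaf: subtree is just this row
    left[i] <<- i + 1L                     # left child directly follows its parent
    lastLeft <- walk(i + 1L)
    right[i] <<- lastLeft + 1L             # right child follows the whole left subtree
    walk(lastLeft + 1L)
  }
  walk(1L)

  data.frame(node = seq_len(n), variableName, left, right)
}

# Tree 3 from df, written out as a plain vector
tree3 <- getChildren(c("x5", "x4", NA, NA, "x3", NA, NA))
print(tree3)
```

For tree 3 this places the root's left child at row 2 and its right child at row 5, which matches the rows used in the hardcoded attempt further down.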
So, if we create some data that looks like this:
set.seed(100)
dat <- data.frame(x1 = runif(10),
                  x2 = runif(10),
                  x3 = runif(10),
                  x4 = runif(10),
                  x5 = runif(10))
I'm trying to find which observations of dat fall into which node.
Attempt at an answer: This isn't really helpful, but for clarity (as I am still trying to solve this), hardcoding it for tree number three would look like this:
# Split df into a list with one data frame per tree, then take tree 3
lists <- df %>% group_by(treeNo) %>% group_split()
tree <- lists[[3]]
# Root split (row 1): observations <= splitValue go left, the rest go right
namesDf <- names(dat[grepl(tree[1, ]$variableName, names(dat))])
dataLeft <- dat[dat[, namesDf] <= tree[1, ]$splitValue, ]
dataRight <- dat[dat[, namesDf] > tree[1, ]$splitValue, ]
# Left child of the root (row 2): split the left subset again
namesDf <- names(dat[grepl(tree[2, ]$variableName, names(dat))])
dataLeft1 <- dataLeft[dataLeft[, namesDf] <= tree[2, ]$splitValue, ]
dataRight1 <- dataLeft[dataLeft[, namesDf] > tree[2, ]$splitValue, ]
# Right child of the root (row 5): split the right subset
namesDf <- names(dat[grepl(tree[5, ]$variableName, names(dat))])
dataLeft2 <- dataRight[dataRight[, namesDf] <= tree[5, ]$splitValue, ]
dataRight2 <- dataRight[dataRight[, namesDf] > tree[5, ]$splitValue, ]
I have been trying to turn this into a loop, but it's proving challenging to work out, and I (obviously) can't hardcode it for every tree. Any suggestions as to how I could solve this?
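For reference, one possible direction is recursion rather than a loop. The following is only a sketch under stated assumptions, not a tested solution for real BART output: the helper name obsPerNode is hypothetical, rows are assumed to be in depth-first, left-first order with NA marking leaves, the data are assumed to contain no missing values, and the split rule is the same as in the hardcoded attempt (<= splitValue goes left, > goes right). The data are rebuilt in base R so the block runs on its own.

```r
# Same trees and data as in the question, rebuilt with base R only
df <- data.frame(
  variableName = c("x2", "x1", NA, NA, NA, "x2", NA, NA, "x5", "x4", NA, NA, "x3", NA, NA),
  splitValue = c(0.542, 0.126, NA, NA, NA, 0.6547, NA, NA, 0.418, 0.234, NA, NA, 0.747, NA, NA),
  treeNo = c(1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3)
)
set.seed(100)
dat <- data.frame(x1 = runif(10), x2 = runif(10), x3 = runif(10),
                  x4 = runif(10), x5 = runif(10))

# Hypothetical helper: returns a list with one element per row (node) of a
# single tree; element i holds the row indices of `data` that reach node i.
obsPerNode <- function(tree, data) {
  nodeObs <- vector("list", nrow(tree))
  pos <- 0L  # index of the tree row currently being visited

  walk <- function(rows) {
    pos <<- pos + 1L
    i <- pos
    nodeObs[[i]] <<- rows
    if (is.na(tree$variableName[i])) return(invisible(NULL))  # leaf: stop here
    goesLeft <- data[[tree$variableName[i]]][rows] <= tree$splitValue[i]
    walk(rows[goesLeft])   # the left subtree occupies the rows that follow
    walk(rows[!goesLeft])  # the right subtree starts once the left one is done
  }
  walk(seq_len(nrow(data)))
  nodeObs
}

# Apply to every tree; the result is named by treeNo ("1", "2", "3")
res <- lapply(split(df, df$treeNo), obsPerNode, data = dat)

# Number of observations in each node of tree 3, in row order
print(lengths(res[["3"]]))
```

Here res[["3"]][[2]] should hold the same observations as dataLeft in the hardcoded version, and res[["3"]][[5]] the same as dataRight, expressed as row indices of dat rather than as subsetted data frames.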