
Similar questions have been asked, for example here and here, but none of the answers can be applied to my issue. I'm trying to determine and count which observations are in each node of a decision tree. However, the tree structure comes from a data frame of trees that I'm creating myself: I extract tree information from the BART package and turn it into a data frame that resembles the one shown below (i.e., df). I need to work with this data frame structure. Aside: I believe the way the trees are ordered in my data frame is called 'depth first'.

For example, my data frame of trees looks like this:

library(dplyr)
df <- tibble(variableName = c("x2", "x1", NA, NA, NA, "x2", NA, NA, "x5", "x4", NA, NA, "x3", NA, NA),
             splitValue = c(0.542, 0.126, NA, NA, NA, 0.6547, NA, NA, 0.418, 0.234, NA, NA, 0.747, NA, NA),
             treeNo = c(1,1,1,1,1,2,2,2,3,3,3,3,3,3,3))

Visually, these trees would look like:

[Figure: the three decision trees encoded in df]

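The trees are laid out left-first as you traverse down df: each split row is followed by its entire left subtree and then its entire right subtree, and NA rows are terminal nodes (leaves). Additionally, all splits are binary, so each non-terminal node has exactly 2 children.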
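To make this layout concrete, here is a small sketch that recovers each split's child rows from one tree's variableName column. It relies only on the rule just described (a split row is followed by its left subtree, then its right subtree) and on every split having two children. tree_children() and subtree_end() are hypothetical names, not part of BART or dplyr.

```r
# Hypothetical helper: for one tree stored in depth-first, left-first order,
# return the row of each split's left and right child (NA for leaves).
tree_children <- function(variableName) {
  n <- length(variableName)
  # a full binary tree has exactly one more leaf (NA row) than it has splits
  stopifnot(sum(is.na(variableName)) == sum(!is.na(variableName)) + 1L)
  left <- right <- rep(NA_integer_, n)

  # Last row of the subtree rooted at row i: a leaf occupies one row; a split
  # occupies itself, then its whole left subtree, then its whole right subtree.
  subtree_end <- function(i) {
    if (is.na(variableName[i])) return(i)
    left_end <- subtree_end(i + 1L)
    subtree_end(left_end + 1L)
  }

  for (i in seq_len(n)) {
    if (!is.na(variableName[i])) {
      left[i]  <- i + 1L                    # left child directly follows its parent
      right[i] <- subtree_end(i + 1L) + 1L  # right child follows the left subtree
    }
  }
  data.frame(row = seq_len(n), variableName = variableName,
             left = left, right = right)
}

# Tree 3 from df
tree_children(c("x5", "x4", NA, NA, "x3", NA, NA))
```

For tree 3 this should give left children 2, 3 and 6 and right children 5, 4 and 7 for the split rows 1, 2 and 5.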

So, if we create some data that looks like this:

set.seed(100)
dat <- data.frame( x1 = runif(10),
                   x2 = runif(10),
                   x3 = runif(10),
                   x4 = runif(10),
                   x5 = runif(10)
)

I'm trying to find which observations of dat fall into which node.

Attempt at an answer: This isn't really helpful, but for clarity (as I am still trying to solve this), hardcoding it for tree number three would look like this:

lists <- df %>% group_by(treeNo) %>% group_split()
tree <- lists[[3]]

# root split (row 1 of tree 3): x5 <= 0.418
namesDf <- names(dat[grepl(tree[1, ]$variableName, names(dat))])
dataLeft <- dat[dat[, namesDf] <= tree[1, ]$splitValue, ]
dataRight <- dat[dat[, namesDf] > tree[1, ]$splitValue, ]

# left child of the root (row 2): x4 <= 0.234
namesDf <- names(dat[grepl(tree[2, ]$variableName, names(dat))])
dataLeft1 <- dataLeft[dataLeft[, namesDf] <= tree[2, ]$splitValue, ]
dataRight1 <- dataLeft[dataLeft[, namesDf] > tree[2, ]$splitValue, ]

# right child of the root (row 5): x3 <= 0.747
namesDf <- names(dat[grepl(tree[5, ]$variableName, names(dat))])
dataLeft2 <- dataRight[dataRight[, namesDf] <= tree[5, ]$splitValue, ]
dataRight2 <- dataRight[dataRight[, namesDf] > tree[5, ]$splitValue, ]

I have been trying to turn this into a loop, but it's proving challenging to work out, and I (obviously) can't hardcode it for every tree. Any suggestions as to how I could solve this?

Electrino
  • Why don't you compute the number of observations in each node while growing the tree, and add it as a new variable in your data frame? That is how I did it when I coded a decision tree. Also, we do not know the path each variable should take down the tree. For instance, we see from `df` that the first tree splits once on `x1` and once on `x2`, but from the data frame it is not clear in which order, whether sequentially, etc. – riccardo-df Jan 23 '22 at 19:38
  • I'm not actually growing the trees. Despite my comment about not using packages, I am in fact extracting the tree data from several different decision tree packages (all of which provide varying information about the trees). I'm taking that info and creating a data frame of trees that I can manipulate. One of the things I can't extract from the packages is which observations are in each node, so I was trying to solve this issue using my data frame of trees – Electrino Jan 23 '22 at 19:43
  • Then you should provide more information about what you are actually doing. From what we have, it is impossible even to just draw the pictures you attached. Can you edit your post adding a reproducible example? – riccardo-df Jan 23 '22 at 19:45
  • I'm not sure what else I could add to the question. This is the data I have to work with. If it helps, the trees are drawn left-first as you traverse down `df`. I have been able to draw trees using this structure, but the code used to draw them is far too long to include here. – Electrino Jan 23 '22 at 20:24
  • For starters, I would add the _several different decision tree packages_ you are extracting tree data from, and at least an example of how you are extracting them. – riccardo-df Jan 23 '22 at 20:38
  • I've edited the question to mention one of the packages I'm using. However, the code to extract the tree data is, again, far too long to include here. Additionally, I need to work with the data frame of trees (`df`) as I have provided it. – Electrino Jan 23 '22 at 21:16
  • @Electrino So do you already have this `df`? That is: do you already have a way of reliably converting the tree into a data frame of that format? – Greg Feb 04 '22 at 15:36
  • @Electrino Also, how many children may each node have? Exactly two? Either one or two? More than two? – Greg Feb 04 '22 at 15:38
  • I should have mentioned that in my question. It's a binary tree, so each node will have exactly two children. I'll add that info to the question – Electrino Feb 04 '22 at 15:40
  • If you're generating the trees using BART, post your BART code – Hong Ooi Feb 05 '22 at 11:25

2 Answers


It seems that we can do "rolling splits" to get what you are looking for. The logic is as follows.

  1. Start with a stack with only one dataframe dat.
  2. For each pair of variableName and splitValue, if they are not NAs, split the top dataframe on that stack into two sub dataframes identified by variableName <= splitValue and variableName > splitValue (the former on top of the latter); if they are NAs, then simply pop the top dataframe.
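The two steps can be sketched with row indices instead of data frames, which keeps the bookkeeping small. This is a stripped-down illustration under the same assumptions (rows in depth-first, left-first order; binary splits; `variable <= value` goes left), not the answer's own function; leaf_rows() is a hypothetical name.

```r
# Hypothetical helper: the stack idea applied to row indices only.
# A split pops the current subset and pushes its right part, then its left
# part (so the left part is on top); a leaf pops the subset that reaches it.
leaf_rows <- function(dat, variableName, splitValue) {
  stack <- list(seq_len(nrow(dat)))  # start with every observation
  leaves <- list()
  for (i in seq_along(variableName)) {
    top <- stack[[length(stack)]]
    stack[[length(stack)]] <- NULL   # pop the top subset
    if (is.na(variableName[i])) {
      # leaf: record the rows that end here, keyed by the leaf's row in the tree
      leaves[[as.character(i)]] <- top
    } else {
      goes_left <- dat[[variableName[i]]][top] <= splitValue[i]
      stack <- c(stack, list(top[!goes_left]), list(top[goes_left]))
    }
  }
  leaves
}
```

With dat and df as in the question, `leaf_rows(dat, c("x2", "x1", NA, NA, NA), c(0.542, 0.126, NA, NA, NA))` returns one vector of row numbers per leaf of tree 1, and `lengths()` of that list gives the number of observations per leaf.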

Here is the code. Note that this kind of state-dependent computation is hard to vectorize, so it's not what R is good at. If you have a lot of trees and performance becomes a serious concern, I'd suggest rewriting the code below using Rcpp.

eval_node <- function(df, x, v) {
  out <- vector("list", length(x))
  stk <- vector("list", sum(is.na(x)))
  pos <- 1L
  stk[[pos]] <- df
  for (i in seq_along(x)) {
    if (!is.na(x[[i]])) {
      subs <- pos + c(0L, 1L)
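      # keep both levels so a one-sided or empty split still yields two data frames
      side <- factor(stk[[pos]][[x[[i]]]] <= v[[i]], levels = c(FALSE, TRUE))
      stk[subs] <- split(stk[[pos]], side)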
      names(stk)[subs] <- trimws(paste0(
        names(stk[pos]), ",", x[[i]], c(">", "<="), v[[i]]
      ), "left", ",")
      out[[i]] <- rev(stk[subs])
      pos <- pos + 1L
    } else {
      out[[i]] <- stk[pos]
      stk[[pos]] <- NULL
      pos <- pos - 1L
    }
  }
  out
}

Then you can apply the function like this.

library(dplyr)

df %>% group_by(treeNo) %>% mutate(node = eval_node(dat, variableName, splitValue))

Output

# A tibble: 15 x 4
# Groups:   treeNo [3]
   variableName splitValue treeNo node            
   <chr>             <dbl>  <dbl> <list>          
 1 x2                0.542      1 <named list [2]>
 2 x1                0.126      1 <named list [2]>
 3 NA               NA          1 <named list [1]>
 4 NA               NA          1 <named list [1]>
 5 NA               NA          1 <named list [1]>
 6 x2                0.655      2 <named list [2]>
 7 NA               NA          2 <named list [1]>
 8 NA               NA          2 <named list [1]>
 9 x5                0.418      3 <named list [2]>
10 x4                0.234      3 <named list [2]>
11 NA               NA          3 <named list [1]>
12 NA               NA          3 <named list [1]>
13 x3                0.747      3 <named list [2]>
14 NA               NA          3 <named list [1]>
15 NA               NA          3 <named list [1]>

where node looks like this:

[[1]]
[[1]]$`x2<=0.542`
          x1        x2        x3        x4        x5
3 0.55232243 0.2803538 0.5383487 0.3486920 0.7775844
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
9 0.54655860 0.3594751 0.5490967 0.9895641 0.2077139

[[1]]$`x2>0.542`
          x1        x2        x3        x4        x5
1  0.3077661 0.6249965 0.5358112 0.4883060 0.3306605
2  0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
5  0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6  0.4837707 0.6690217 0.1714202 0.8894535 0.4912318
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859


[[2]]
[[2]]$`x2<=0.542,x1<=0.126`
          x1        x2        x3        x4        x5
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034

[[2]]$`x2<=0.542,x1>0.126`
         x1        x2        x3        x4        x5
3 0.5523224 0.2803538 0.5383487 0.3486920 0.7775844
7 0.8124026 0.2046122 0.7703016 0.1804072 0.7803585
8 0.3703205 0.3575249 0.8819536 0.6293909 0.8842270
9 0.5465586 0.3594751 0.5490967 0.9895641 0.2077139


[[3]]
[[3]]$`x2<=0.542,x1<=0.126`
          x1        x2        x3        x4        x5
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034


[[4]]
[[4]]$`x2<=0.542,x1>0.126`
         x1        x2        x3        x4        x5
3 0.5523224 0.2803538 0.5383487 0.3486920 0.7775844
7 0.8124026 0.2046122 0.7703016 0.1804072 0.7803585
8 0.3703205 0.3575249 0.8819536 0.6293909 0.8842270
9 0.5465586 0.3594751 0.5490967 0.9895641 0.2077139


[[5]]
[[5]]$`x2>0.542`
          x1        x2        x3        x4        x5
1  0.3077661 0.6249965 0.5358112 0.4883060 0.3306605
2  0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
5  0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6  0.4837707 0.6690217 0.1714202 0.8894535 0.4912318
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859


[[6]]
[[6]]$`x2<=0.6547`
          x1        x2        x3        x4        x5
1 0.30776611 0.6249965 0.5358112 0.4883060 0.3306605
3 0.55232243 0.2803538 0.5383487 0.3486920 0.7775844
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
9 0.54655860 0.3594751 0.5490967 0.9895641 0.2077139

[[6]]$`x2>0.6547`
          x1        x2        x3        x4        x5
2  0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
5  0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6  0.4837707 0.6690217 0.1714202 0.8894535 0.4912318
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859


[[7]]
[[7]]$`x2<=0.6547`
          x1        x2        x3        x4        x5
1 0.30776611 0.6249965 0.5358112 0.4883060 0.3306605
3 0.55232243 0.2803538 0.5383487 0.3486920 0.7775844
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
9 0.54655860 0.3594751 0.5490967 0.9895641 0.2077139


[[8]]
[[8]]$`x2>0.6547`
          x1        x2        x3        x4        x5
2  0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
5  0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6  0.4837707 0.6690217 0.1714202 0.8894535 0.4912318
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859


[[9]]
[[9]]$`x5<=0.418`
          x1        x2        x3        x4        x5
1  0.3077661 0.6249965 0.5358112 0.4883060 0.3306605
9  0.5465586 0.3594751 0.5490967 0.9895641 0.2077139
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859

[[9]]$`x5>0.418`
          x1        x2        x3        x4        x5
2 0.25767250 0.8821655 0.7108038 0.9285051 0.8651205
3 0.55232243 0.2803538 0.5383487 0.3486920 0.7775844
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
5 0.46854928 0.7625511 0.4201015 0.6952741 0.6033244
6 0.48377074 0.6690217 0.1714202 0.8894535 0.4912318
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270


[[10]]
[[10]]$`x5<=0.418,x4<=0.234`
          x1        x2        x3        x4        x5
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859

[[10]]$`x5<=0.418,x4>0.234`
         x1        x2        x3        x4        x5
1 0.3077661 0.6249965 0.5358112 0.4883060 0.3306605
9 0.5465586 0.3594751 0.5490967 0.9895641 0.2077139


[[11]]
[[11]]$`x5<=0.418,x4<=0.234`
          x1        x2        x3        x4        x5
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859


[[12]]
[[12]]$`x5<=0.418,x4>0.234`
         x1        x2        x3        x4        x5
1 0.3077661 0.6249965 0.5358112 0.4883060 0.3306605
9 0.5465586 0.3594751 0.5490967 0.9895641 0.2077139


[[13]]
[[13]]$`x5>0.418,x3<=0.747`
         x1        x2        x3        x4        x5
2 0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
3 0.5523224 0.2803538 0.5383487 0.3486920 0.7775844
5 0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6 0.4837707 0.6690217 0.1714202 0.8894535 0.4912318

[[13]]$`x5>0.418,x3>0.747`
          x1        x2        x3        x4        x5
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270


[[14]]
[[14]]$`x5>0.418,x3<=0.747`
         x1        x2        x3        x4        x5
2 0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
3 0.5523224 0.2803538 0.5383487 0.3486920 0.7775844
5 0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6 0.4837707 0.6690217 0.1714202 0.8894535 0.4912318


[[15]]
[[15]]$`x5>0.418,x3>0.747`
          x1        x2        x3        x4        x5
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
ekoam
  • This is a great answer. I'm in the process of trying to figure out how to extract the row indices from `node` when using your code. The tricky part is that the `node` object has duplicates. For example, `df$node[[2]]` contains the same info as `df$node[[3]]` and `df$node[[4]]` combined. I am trying to remove any duplicates. I ultimately want to see how many times pairs of observations (i.e., row indices) appear in the same node. For example, observations 3 and 4 appear together 3 times, but due to the duplication, observations 3 and 4 appear 4 times in the `node` object. But this is a great start – Electrino Feb 07 '22 at 02:09
  • I don't think they are duplicates. Although `node` is a somewhat badly named column, it is really just there to give you all the information you need for each `variableName`. Note that `variableName` could represent either a node or a leaf (i.e., `NA`) in your tree. The information for a node is the data for each branch and the information for a leaf is the branch itself; that's why we need some repeats. If you want to "hide" the information for leaves, comment out this line `out[[i]] <- stk[pos]`. If you want to remove leaves from the data, then `df %>% filter(!is.na(variableName))`. @Electrino – ekoam Feb 07 '22 at 02:29
  • You're right! I wasn't thinking correctly when I commented about the duplicates. You don't have duplicates. I'm marking this the correct answer. It solves my problem. Thanks for your time and effort – Electrino Feb 07 '22 at 02:44
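For the follow-up goal of counting how often two observations share a node, one simple approach is to build a list of row-index vectors, one per node you want to count, and accumulate a pair-count matrix. cooccurrence() is a hypothetical helper. The row indices of a data frame stored in node can be recovered with `as.integer(rownames(...))`, since the subsets keep dat's original row names.

```r
# Hypothetical helper: count how often each pair of observations falls in the
# same node. `groups` is a list of integer vectors of row indices (one vector
# per node you want to count); `n` is the number of observations, nrow(dat).
cooccurrence <- function(groups, n) {
  counts <- matrix(0L, nrow = n, ncol = n)
  for (rows in groups) {
    # every pair (i, j) inside this node gets +1, including i == j
    counts[rows, rows] <- counts[rows, rows] + 1L
  }
  counts
}
```

If you pass only the leaves, every observation lands in exactly one leaf per tree, so a pair is counted at most once per tree and the diagonal equals the number of trees. Passing internal nodes as well counts a pair once for every node on their shared path, which is where the repetition discussed in the comments comes from.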

There is still much room for optimization, but this is my attempt. Your trees seem to be structured depth-first, with each left child always immediately following its parent node:

library(dplyr)
df <- tibble(variableName = c("x2", "x1", NA, NA, NA, "x2", NA, NA, "x5", "x4", NA, NA, "x3", NA, NA),
             splitValue = c(0.542, 0.126, NA, NA, NA, 0.6547, NA, NA, 0.418, 0.234, NA, NA, 0.747, NA, NA),
             treeNo = c(1,1,1,1,1,2,2,2,3,3,3,3,3,3,3))

Given the data to be matched:

set.seed(100)
dat <- data.frame( x1 = runif(10),
                   x2 = runif(10),
                   x3 = runif(10),
                   x4 = runif(10),
                   x5 = runif(10)
)
dat
##>           x1        x2        x3        x4        x5
##>1  0.30776611 0.6249965 0.5358112 0.4883060 0.3306605
##>2  0.25767250 0.8821655 0.7108038 0.9285051 0.8651205
##>3  0.55232243 0.2803538 0.5383487 0.3486920 0.7775844
##>4  0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
##>5  0.46854928 0.7625511 0.4201015 0.6952741 0.6033244
##>6  0.48377074 0.6690217 0.1714202 0.8894535 0.4912318
##>7  0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
##>8  0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
##>9  0.54655860 0.3594751 0.5490967 0.9895641 0.2077139
##>10 0.17026205 0.6902905 0.2777238 0.1302889 0.3070859

makeTree is a higher-order function. It returns a list whose size element drives the recursion and whose fn element is a function that maps a row of values to a node:

makeTree <- function(dat, r = 1) {
  ## `dat` here is the data frame of ONE tree (a group of df with an extra
  ## `id` column), not the observation data.
  ## Returns a list of two elements:
  ## - size: the number of rows taken by the node at row r and its descendants
  ## - fn: a function of one argument (a named vector or a row of a data
  ##   frame) that returns the `id` of the node the argument falls into
  stopifnot(r <= nrow(dat))
  vname <- pull(dat, variableName)[r]
  splitVal <- pull(dat, splitValue)[r]
  if (is.na(vname)) {
    ## terminal node: occupies a single row
    res <- list(size = 1, # offset used to locate the right sibling
                fn = function(z) {
                  pull(dat, "id")[r]
                })
    return(res)
  } else {
    ## compute the left and right functions;
    ## the tree is traversed depth-first
    fnleft <- makeTree(dat, r + 1) # the left child is always positioned
                                   # right after its parent
    fnright <- makeTree(dat, r + fnleft$size + 1)
    return(list(size = fnleft$size + fnright$size + 1,
                fn = function(z) {
                  if (z[vname] <= splitVal)
                    fnleft$fn(z)
                  else
                    fnright$fn(z)
                }))
  }
}
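The size bookkeeping in makeTree also works without building closures. The sketch below (hypothetical assign_leaves(), assuming the same depth-first layout and that `<=` goes left) routes all observations through one tree in a single recursive pass over index vectors, and returns, for each row of dat, the tree-local row number of its leaf.

```r
# Hypothetical vectorised variant: route every row of `dat` through one tree.
assign_leaves <- function(dat, variableName, splitValue) {
  leaf <- integer(nrow(dat))
  # visit(r, rows) handles the subtree rooted at row r for the observations in
  # `rows`, and returns how many rows that subtree occupies (like makeTree's size).
  visit <- function(r, rows) {
    if (is.na(variableName[r])) {
      leaf[rows] <<- r               # these observations end in the leaf at row r
      return(1L)
    }
    goes_left <- dat[[variableName[r]]][rows] <= splitValue[r]
    left_size <- visit(r + 1L, rows[goes_left])
    right_size <- visit(r + 1L + left_size, rows[!goes_left])
    1L + left_size + right_size
  }
  visit(1L, seq_len(nrow(dat)))
  leaf
}
```

Calling it once per tree, e.g. `lapply(split(df, df$treeNo), function(t) assign_leaves(dat, t$variableName, t$splitValue))`, should give one leaf vector per tree. Note these are row numbers within each tree, whereas makeTree's id refers to rows of the whole df.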

Now makeTree is applied to each tree to produce a list of matching functions:

treefns <- df |>
  mutate(id = row_number()) |>
  group_by(treeNo) |>
  group_split() |>
  purrr::map(makeTree) |>
  purrr::map("fn")

Finally, each row of your dataframe dat is matched to a node of the tree:

library(tidyr)  # pivot_longer() comes from tidyr, not dplyr

apply(dat, 1, function(z) sapply(treefns, function(fn) fn(z))) |>
  t() |>
  data.frame() |>
  rename_with(function(z) paste0("TREE", gsub("X", "", z))) |>
  cbind(dat) |>
  pivot_longer(cols = starts_with("TREE"),
               names_to = "TREE",
               values_to = "NODE") |>
  sample_n(10)  # a random 10 of the 30 (observation, tree) rows

##> A tibble: 10 x 7
##>       x1    x2    x3    x4    x5 TREE   NODE
##>    <dbl> <dbl> <dbl> <dbl> <dbl> <chr> <int>
##> 1 0.170  0.690 0.278 0.130 0.307 TREE3    11
##> 2 0.170  0.690 0.278 0.130 0.307 TREE2     8
##> 3 0.370  0.358 0.882 0.629 0.884 TREE2     7
##> 4 0.308  0.625 0.536 0.488 0.331 TREE1     5
##> 5 0.370  0.358 0.882 0.629 0.884 TREE1     4
##> 6 0.552  0.280 0.538 0.349 0.778 TREE3    14
##> 7 0.547  0.359 0.549 0.990 0.208 TREE1     4
##> 8 0.370  0.358 0.882 0.629 0.884 TREE3    15
##> 9 0.547  0.359 0.549 0.990 0.208 TREE2     7
##>10 0.0564 0.398 0.749 0.954 0.827 TREE2     7
Stefano Barbi