This is my first time asking a question here, so please be kind. I apologise if this isn't the correct place to ask (the line between an R programming question and a stats question is a bit blurred to me); if so, I will happily try Stack Exchange.
I am running cforest in R (party package) to predict a numerical response variable, and using the excellent 'get_cTree' workaround suggested by @Marco Sandri here: https://stackoverflow.com/a/34534978/9989544 to generate a tree to try to understand the rules it's using to split (that's of interest to me even though my main focus is variable importance).
I was expecting the weights, across all of the nodes, to sum to my total sample size, which is what happens if I run a single 'ctree'.
However, when using Marco Sandri's get_cTree code, what actually happens is that several pairs of node weights sum to my sample size, while the remaining weights do not sum to my sample size at all. The overall total of the weights is larger than my total sample size.
Is this a natural consequence of trying to get a tree out of a conditional forest (i.e. it isn't truly partitioning the data into individual nodes), or is this something that could be resolved with programming?
Here is an example (get_cTree code from Marco Sandri). For the iris dataset, n=150. The sum of the weights for the nodes that I get for the cforest is 566, and it's 150 using ctree (party package).
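For reference, this is roughly how the totals above can be checked. It is a minimal sketch, not part of Marco Sandri's code: `sum_terminal_weights`, `leaf` and `split_node` are hypothetical helper names, and the toy list only mimics the `terminal`, `weights`, `left` and `right` fields that the get_cTree code below reads from each node.

```r
# Hypothetical constructors for a toy nested-list tree that mimics the
# node fields used by the get_cTree code (terminal, weights, left, right).
leaf <- function(w) list(terminal = TRUE, weights = w)
split_node <- function(w, left, right) {
  list(terminal = FALSE, weights = w, left = left, right = right)
}

# Recursively add up the weights stored in the terminal nodes only.
sum_terminal_weights <- function(node) {
  if (isTRUE(node$terminal)) {
    return(sum(node$weights))
  }
  sum_terminal_weights(node$left) + sum_terminal_weights(node$right)
}

# Toy tree on 5 observations in which the terminal nodes partition the data.
toy <- split_node(
  w = c(1, 1, 1, 1, 1),
  left = leaf(c(1, 1, 0, 0, 0)),
  right = split_node(
    w = c(0, 0, 1, 1, 1),
    left = leaf(c(0, 0, 1, 0, 0)),
    right = leaf(c(0, 0, 0, 1, 1))
  )
)

total <- sum_terminal_weights(toy)
print(total)  # 5, i.e. the toy sample size
```

Assuming the `tree` slot of a BinaryTree holds the same kind of nested list, the same helper could be applied to the objects created below as `sum_terminal_weights(iris_ctree@tree)` and `sum_terminal_weights(iristree@tree)`.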
library(party)
# Walk the tree recursively and recompute the weights of every node's children
update_tree <- function(x, dt) {
  x <- update_weights(x, dt)
  if (!x$terminal) {
    x$left <- update_tree(x$left, dt)
    x$right <- update_tree(x$right, dt)
  }
  x
}

# Set the 0/1 weights of the two children of node x from its primary split
update_weights <- function(x, dt) {
  splt <- x$psplit
  spltClass <- attr(splt, "class")
  spltVarName <- splt$variableName
  spltVar <- dt[, spltVarName]
  spltVarLev <- levels(spltVar)
  if (!is.null(spltClass)) {
    if (spltClass == "nominalSplit") {
      # Factor split: the split point flags which levels go left
      attr(x$psplit$splitpoint, "levels") <- spltVarLev
      filt <- spltVar %in% spltVarLev[as.logical(x$psplit$splitpoint)]
    } else {
      # Numeric/ordered split: values up to the split point go left
      filt <- (spltVar <= splt$splitpoint)
    }
    x$left$weights <- as.numeric(filt)
    x$right$weights <- as.numeric(!filt)
  }
  x
}

# Extract the k-th tree of a cforest as a BinaryTree object
get_cTree <- function(cf, k = 1) {
  dt <- cf@data@get("input")
  tr <- party:::prettytree(cf@ensemble[[k]], names(dt))
  tr_updated <- update_tree(tr, dt)
  new("BinaryTree", tree = tr_updated, data = cf@data, responses = cf@responses,
      cond_distr_response = cf@cond_distr_response, predict_response = cf@predict_response)
}

# Copies of the iris columns under names without dots
SepalLength <- as.numeric(iris$Sepal.Length)
SepalWidth <- as.numeric(iris$Sepal.Width)
PetalLength <- as.numeric(iris$Petal.Length)
PetalWidth <- as.numeric(iris$Petal.Width)
Species <- as.factor(iris$Species)
mtry <- ceiling(sqrt(4))

# Conditional forest, then the first tree extracted with get_cTree
set.seed(1)
iris_cforest <- cforest(PetalLength ~ SepalLength + SepalWidth + PetalWidth + Species,
                        controls = cforest_unbiased(ntree = 1000, mtry = mtry))
iristree <- get_cTree(iris_cforest)
iristree
plot(iristree)

# Single conditional inference tree for comparison
set.seed(1)
iris_ctree <- ctree(PetalLength ~ SepalLength + SepalWidth + PetalWidth + Species)
iris_ctree
plot(iris_ctree)
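One possible reading of the behaviour, offered as an assumption to check rather than a confirmed diagnosis: `update_weights` builds `filt` from every row of `dt`, so the two children of each split always sum to n regardless of which rows actually reached the parent node. The toy sketch below (made-up data frame `d`, hypothetical variable names) contrasts that with weights restricted to the parent's rows.

```r
# Made-up data: 6 observations, two numeric predictors.
d <- data.frame(a = c(1, 2, 3, 4, 5, 6), b = c(1, 5, 1, 5, 1, 5))
n <- nrow(d)

# Toy tree: the root splits on a <= 3, its left child then splits on b <= 2.
root_filt <- d$a <= 3
child_filt <- d$b <= 2

# (a) Child weights computed on all rows, as update_weights appears to do.
full_left <- as.numeric(child_filt)
full_right <- as.numeric(!child_filt)
root_right <- as.numeric(!root_filt)
total_full <- sum(full_left) + sum(full_right) + sum(root_right)

# (b) Child weights restricted to the rows that reached the parent node.
parent_w <- as.numeric(root_filt)
nested_left <- parent_w * as.numeric(child_filt)
nested_right <- parent_w * as.numeric(!child_filt)
total_nested <- sum(nested_left) + sum(nested_right) + sum(root_right)

print(c(n = n, full = total_full, nested = total_nested))
```

In (a) the sibling pair `full_left` and `full_right` sums to n on its own, so the terminal-node total exceeds n; in (b) the terminal nodes partition the 6 rows. Whether multiplying by the parent's weights inside `update_weights` is the appropriate change for a tree taken from a cforest is exactly what I am unsure about.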