I want to turn a fitted rpart decision tree into SQL, and I am nearly certain that someone has done this before me.
The final scoring code will not run in R but in SQL.
# library() stops immediately when rpart is not installed; require() would
# only return FALSE with a warning and let the script fail later on rpart().
library(rpart)
# Fit a tree predicting clazz from every other column of df.
rpart.m <- rpart(clazz ~ ., data = df)
# The original snippet bound the fit to both names, so both are kept.
model <- rpart.m
printcp(rpart.m)
summary(rpart.m)
# Auto-print the fitted tree (the node listing shown below).
rpart.m
This produces output like the following:
n= 2500
node), split, n, deviance, yval
* denotes terminal node
1) root 2500 505.753600 0.2816000
2) x>=3.44898 250 0.000000 0.0000000 *
3) x< 3.44898 2250 483.726200 0.3128889
6) x< -1.44898 250 0.000000 0.0000000 *
7) x>=-1.44898 2000 456.192000 0.3520000
14) y< -1.44898 200 0.000000 0.0000000 *
etc...
I really don't want to write a parser for the text in order to generate SQL case statements.
I found https://www.r-bloggers.com/create-sql-rules-from-rpart-model/, but it only goes halfway: it generates a CASE statement that returns the node numbers, not the final predicted values.
Any suggestions?
Update
I posted code as requested and I may have successfully adapted/fixed tomasgreif's code to include the predicted values.
#' Toy data for a strongly non-linear case: a ring ("circle") on a square grid
#' Grid has floor(sqrt(n)) points per axis, spanning [-2, 4] in both x and y.
n <- 2500
df <- expand.grid(
  x = seq(-2, 4, length.out = floor(sqrt(n))),
  y = seq(-2, 4, length.out = floor(sqrt(n)))
)
# Centre of the grid and radius of the ring (35% of the x extent).
cx <- mean(df$x)
cy <- mean(df$y)
r <- diff(range(df$x)) * 0.35
thinness <- 0.6
# Score is 1 exactly on the ring and decays with the distance of the squared
# radius from r^2; a larger thinness makes the decay faster.
df$clazz <- 1 / (1 + abs((df$x - cx)^2 + (df$y - cy)^2 - r^2) * thinness)
# Introduce noise: overwrite 5% of the rows with uniform random scores.
df$clazz[sample(nrow(df), nrow(df) * 0.05)] <- runif(nrow(df) * 0.05)
# Threshold the score into a 0/1 label.
df$clazz <- round(df$clazz)
#' Scatter-plot df$x against df$y, encoding one column as colour and symbol.
#'
#' @param df data frame with columns x and y plus the column named by `what`
#' @param what name of the column to display; also used as the plot title
#' @param numcolors scale factor mapping the column's values onto the
#'   heat.colors palette
plotxyfit=function(df,what='fit',numcolors=5) {
  shown <- as.numeric(df[[what]])
  shades <- heat.colors(5 + numcolors)
  # Colour index is 2 + numcolors * value; plotting symbol is 1 + round(value).
  plot(
    df$x, df$y,
    col = shades[round(2 + numcolors * shown)],
    pch = 1 + round(shown)
  )
  title(what)
}
# Show the raw labels before fitting anything.
plotxyfit(df, "clazz")

#' build a *tree* decision tree model
# library() is used instead of require() throughout: it stops with an error
# when a package is missing, whereas require() only returns FALSE with a
# warning and lets the script fail later at the first call into the package.
library(tree)
tree.m <- tree(clazz ~ ., df)
summary(tree.m)
plot(tree.m)
text(tree.m, pretty = 0)
tree.m

#' build a rpart decision tree model
library(rpart)
rpart.m <- rpart(clazz ~ ., data = df)
printcp(rpart.m)
summary(rpart.m)
rpart.m

library(rpart.plot)
plot(rpart.m) # Will make a mess of the plot
text(rpart.m)
rpart.plot(rpart.m, tweak = 1.5) # nicer plot
prp(rpart.m) # really nice plot

# Score the training grid with both models and compare the fits visually.
df$fit.t <- predict(tree.m, df)
df$fit.rp <- predict(rpart.m, df)
plotxyfit(df, "fit.t") # looks rather poor
plotxyfit(df, "fit.rp") # does an ok job
summary(df)

#' get output as pmml xml
library(pmml)
pmml(rpart.m)
#Adapted from https://gist.github.com/tomasgreif/6038822
#'
#' Rpart rules are changed to sql CASE statement.
#'
#' Every leaf of the tree becomes one `when <conditions> then <prediction>`
#' branch. The conditions are the splits on the path from the root to that
#' leaf, joined with AND; the prediction is the leaf's `yval` taken from
#' `model$frame`, followed by a SQL comment carrying the node number.
#'
#' Splits on numeric columns of `df` are emitted as `column <op> threshold`;
#' splits on any other column are emitted as `column in ('a','b',...)`.
#'
#' NOTE(review): the prediction is always formatted with "%f" from
#' `model$frame$yval`. For classification trees this is presumably a numeric
#' class code rather than the class label -- confirm before using on them.
#'
#' @param df data frame used for rpart model (only consulted to decide
#'   whether a split variable is numeric)
#' @param model rpart model
#' @return A length-one character vector holding the SQL CASE expression.
#' @export
#' @examples
#' parse.tree.to.sql(df=kyphosis,model=rpart(data=kyphosis,formula=Kyphosis~.))
#' parse.tree.to.sql(df=mtcars,model=rpart(data=mtcars,formula=am~.))
#' parse.tree.to.sql(df=iris,model=rpart(data=iris,formula=Species~.))
#' \dontrun{
#' # german_data is not defined in this file; supply your own copy.
#' x <- german_data
#' x$gbbin <- NULL
#' model <- rpart(data=x,formula=gb~.)
#' parse.tree.to.sql(x,model)
#' }
parse.tree.to.sql <- function (df=NULL, model=NULL) { #https://gist.github.com/tomasgreif/6038822
  if (!is.data.frame(df)) {
    stop("`df` must be the data frame used to fit `model`", call. = FALSE)
  }
  if (!is.list(model) || is.null(model$frame)) {
    stop("`model` must be a fitted rpart model", call. = FALSE)
  }

  frame <- model$frame
  leaf_ids <- rownames(frame)[frame$var == "<leaf>"]

  # path.rpart() prints each path to the console; capture that output and
  # throw it away. The assignment still happens in this function's frame.
  invisible(capture.output({
    leaf_paths <- path.rpart(model, leaf_ids)
  }))

  # Two-character operators come first so ">=" wins over ">" and "=" when
  # they start at the same position.
  operators <- c("<=", ">=", "<", ">", "=")

  # Split one path component ("x>=3.4", "x< 3.4", "g=a,b") at its first
  # operator into variable / operator / value.
  split_component <- function(component) {
    starts <- vapply(
      operators,
      function(op) as.integer(regexpr(op, component, fixed = TRUE)),
      integer(1),
      USE.NAMES = FALSE
    )
    if (all(starts < 0L)) {
      stop("Cannot parse rpart rule component: ", component, call. = FALSE)
    }
    starts[starts < 0L] <- .Machine$integer.max
    first <- which.min(starts)
    op <- operators[first]
    at <- starts[first]
    list(
      variable = trimws(substr(component, 1L, at - 1L)),
      operator = op,
      value = substr(component, at + nchar(op), nchar(component))
    )
  }

  # Render one path component as a SQL predicate.
  condition_sql <- function(component) {
    parts <- split_component(component)
    if (!(parts$variable %in% names(df))) {
      stop("Split variable `", parts$variable, "` is not a column of `df`",
           call. = FALSE)
    }
    if (is.numeric(df[[parts$variable]])) {
      # Thresholds can carry a leading blank ("x< 3.4"); trim it.
      paste(parts$variable, parts$operator, trimws(parts$value))
    } else {
      lvls <- unlist(strsplit(parts$value, ",", fixed = TRUE))
      # Double embedded single quotes so the SQL string literals stay valid.
      lvls <- gsub("'", "''", lvls, fixed = TRUE)
      paste0(parts$variable, " in (",
             paste0("'", lvls, "'", collapse = ","), ")")
    }
  }

  when_clauses <- vapply(
    names(leaf_paths),
    function(node_id) {
      components <- leaf_paths[[node_id]]
      # Every path starts with the pseudo-component "root"; it is no condition.
      components <- components[components != "root"]
      conditions <- vapply(components, condition_sql, character(1),
                           USE.NAMES = FALSE)
      if (length(conditions) == 0L) {
        # A tree without splits has a single leaf that matches every row.
        conditions <- "1=1"
      }
      # Look the prediction up in the model passed in, not in a global object.
      prediction <- frame$yval[rownames(frame) == node_id]
      paste0("when ", paste(conditions, collapse = " AND "), " then ",
             sprintf("%f /*node %s */", prediction, node_id))
    },
    character(1),
    USE.NAMES = FALSE
  )

  paste(c("case ", when_clauses, " end "), collapse = " ")
}
# Build the CASE expression for the rpart fit and write it to the console.
results <- parse.tree.to.sql(df = df, model = rpart.m)
# If the output looks truncated, raise the console line limit in RStudio.
cat(results)
Among other things, this produces:
case when y < -1.449 then 0.012000 /*node 2 */ when y >= -1.449 AND y >= 3.327 then 0.050000 /*node 6 */ when y >= -1.449 AND y < 3.327 AND x >= 3.449 then 0.025641 /*node 14 */ when y >= -1.449 AND y < 3.327 AND x < 3.449 AND x < -1.449 then 0.030769 /*node 30 */ when y >= -1.449 AND y < 3.327 AND x < 3.449 AND x >= -1.449 AND y < 2.469 AND x < 2.469 AND y >= -0.3469 AND x >= -0.3469 then 0.079395 /*node 496 */ when y >= -1.449 AND y < 3.327 AND x < 3.449 AND x >= -1.449 AND y < 2.469 AND x < 2.469 AND y >= -0.3469 AND x < -0.3469 then 0.753623 /*node 497 */ when y >= -1.449 AND y < 3.327 AND x < 3.449 AND x >= -1.449 AND y < 2.469 AND x < 2.469 AND y < -0.3469 AND x < -0.5918 then 0.174603 /*node 498 */ when y >= -1.449 AND y < 3.327 AND x < 3.449 AND x >= -1.449 AND y < 2.469 AND x < 2.469 AND y < -0.3469 AND x >= -0.5918 then 0.746667 /*node 499 */ when y >= -1.449 AND y < 3.327 AND x < 3.449 AND x >= -1.449 AND y < 2.469 AND x >= 2.469 AND y < -0.5918 then 0.142857 /*node 250 */ when y >= -1.449 AND y < 3.327 AND x < 3.449 AND x >= -1.449 AND y < 2.469 AND x >= 2.469 AND y >= -0.5918 then 0.775000 /*node 251 */ when y >= -1.449 AND y < 3.327 AND x < 3.449 AND x >= -1.449 AND y >= 2.469 AND x >= 2.714 then 0.095238 /*node 126 */ when y >= -1.449 AND y < 3.327 AND x < 3.449 AND x >= -1.449 AND y >= 2.469 AND x < 2.714 AND x < -0.5918 then 0.163265 /*node 254 */ when y >= -1.449 AND y < 3.327 AND x < 3.449 AND x >= -1.449 AND y >= 2.469 AND x < 2.714 AND x >= -0.5918 then 0.814815 /*node 255 */ end