3

I am nearly certain that someone has done this before me.

The final code will not be in R but rather SQL.

# Fit an rpart tree predicting clazz from every other column of df.
# library() is used instead of require(): it stops with an error right here
# if rpart is not installed, rather than returning FALSE and failing later.
library(rpart)
# Both names are bound to the same fitted model, as in the original snippet.
model <- rpart.m <- rpart(clazz ~ ., data = df)
printcp(rpart.m)  # complexity-parameter table
summary(rpart.m)  # detailed per-node summary
rpart.m           # prints the tree as text

This produces output like the following:

n= 2500 
node), split, n, deviance, yval
      * denotes terminal node
  1) root 2500 505.753600 0.2816000  
    2) x>=3.44898 250   0.000000 0.0000000 *
    3) x< 3.44898 2250 483.726200 0.3128889  
      6) x< -1.44898 250   0.000000 0.0000000 *
      7) x>=-1.44898 2000 456.192000 0.3520000  
       14) y< -1.44898 200   0.000000 0.0000000 *
etc...

I really don't want to write a parser for the text in order to generate SQL case statements.

I noticed https://www.r-bloggers.com/create-sql-rules-from-rpart-model/, but it only goes halfway: it generates a CASE statement that returns the node numbers, not the final predictions.

Any suggestions?

Update

I posted code as requested and I may have successfully adapted/fixed tomasgreif's code to include the predicted values.

#' Toy data for a strongly non-linear case:
#' a noisy ring laid out on a regular x/y grid held in a data.frame
n <- 2500
df <- expand.grid(
  x = seq(-2, 4, length.out = floor(sqrt(n))),
  y = seq(-2, 4, length.out = floor(sqrt(n)))
)
# Centre of the grid and radius of the ring.
cx <- mean(df$x)
cy <- mean(df$y)
r <- (max(df$x) - min(df$x)) * 0.35
thinness <- 0.6
# Score is 1 exactly on the ring and decays with the squared-distance error.
df$clazz <- with(df, 1 / (1 + abs(((x - cx)^2 + (y - cy)^2) - r^2) * thinness))
# Overwrite 5% of the rows with uniform noise.
df$clazz[sample(nrow(df), nrow(df) * 0.05)] <- runif(nrow(df) * 0.05)
# Collapse the score to a 0/1 label.
df$clazz <- round(df$clazz)
#' Scatter-plot the x/y points, encoding one column as colour and symbol.
#'
#' @param df data frame with columns x, y and the column named by `what`
#' @param what name of the column to visualise; also used as the plot title
#' @param numcolors scale factor mapping the column's values to palette indices
plotxyfit <- function(df, what = 'fit', numcolors = 5) {
  values <- as.numeric(df[[what]])
  shades <- heat.colors(5 + numcolors)
  # Offset by 2 so a value of 0 still maps to a valid palette index.
  point_col <- shades[round(2 + numcolors * values)]
  point_pch <- 1 + round(values)
  plot(df$x, df$y, col = point_col, pch = point_pch)
  title(what)
}
plotxyfit(df, 'clazz')

# Packages are loaded with library() rather than require(): library() stops
# with an error if a package is missing, instead of returning FALSE and
# letting the script fail later with "could not find function".

#' build a *tree* decision tree model
library(tree)
tree.m <- tree(clazz ~ ., df)
summary(tree.m)
plot(tree.m)
text(tree.m, pretty = 0)
tree.m

#' build a rpart decision tree model
library(rpart)
rpart.m <- rpart(clazz ~ ., data = df)
printcp(rpart.m)
summary(rpart.m)
rpart.m
library(rpart.plot)
plot(rpart.m)                     # Will make a mess of the plot
text(rpart.m)
rpart.plot(rpart.m, tweak = 1.5)  # nicer plot
prp(rpart.m)                      # really nice plot

#' compare the in-sample predictions of both models
df$fit.t <- predict(tree.m, df)
df$fit.rp <- predict(rpart.m, df)
plotxyfit(df, 'fit.t')   # looks rather poor
plotxyfit(df, 'fit.rp')  # does an ok job
summary(df)

#' get output as pmml xml
library(pmml)
pmml(rpart.m)


#Adapted from https://gist.github.com/tomasgreif/6038822
#'
#' Rpart rules are changed to sql CASE statement.
#'
#' Every leaf of the tree becomes one `when <conditions> then <prediction>`
#' branch: the conditions are the splits on the path from the root to that
#' leaf, joined with AND, and the prediction is the leaf's `yval` taken from
#' `model$frame`. The leaf's node number is appended as a SQL comment.
#'
#' @param df data frame used for rpart model; it is only consulted to decide
#'   whether a split variable is numeric (comparison) or not (`in (...)` list)
#' @param model rpart model
#' @return a length-one character vector holding the CASE expression
#' @export
#' @examples
#' parse.tree.to.sql(df=kyphosis,model=rpart(data=kyphosis,formula=Kyphosis~.))
#' parse.tree.to.sql(df=mtcars,model=rpart(data=mtcars,formula=am~.))
#' parse.tree.to.sql(df=iris,model=rpart(data=iris,formula=Species~.))
#' x <- german_data
#' x$gbbin <- NULL
#' model <- rpart(data=x,formula=gb~.)
#' parse.tree.to.sql(x,model)
parse.tree.to.sql <- function (df=NULL, model=NULL) { #https://gist.github.com/tomasgreif/6038822
  stopifnot(is.data.frame(df), inherits(model, "rpart"))

  frame <- model$frame
  leaf_ids <- rownames(frame)[frame$var == "<leaf>"]
  # print.it = FALSE stops path.rpart from echoing every path to the console.
  rpart.rules <- rpart::path.rpart(model, leaf_ids, print.it = FALSE)

  # Two-character operators come first so ">=" is not mistaken for ">" or "=".
  operators <- c("<=", ">=", "<", ">", "=")

  # Translate one split label (e.g. "x>=3.449") into a SQL predicate.
  # Returns NULL for the "root" label, which carries no condition.
  to_predicate <- function(component) {
    if (identical(component, "root")) {
      return(NULL)
    }
    pieces <- lapply(operators, function(op) {
      strsplit(component, op, fixed = TRUE)[[1]]
    })
    hit <- which(lengths(pieces) > 1L)[1]
    if (is.na(hit)) {
      stop("cannot parse rpart split label: ", component, call. = FALSE)
    }
    elements <- pieces[[hit]]
    if (!(elements[1] %in% names(df))) {
      stop("split variable '", elements[1], "' is not a column of df",
           call. = FALSE)
    }
    if (is.numeric(df[[elements[1]]])) {
      paste(elements[1], operators[hit], elements[2])
    } else {
      # Double any single quote so a level such as O'Brien stays valid SQL.
      lvls <- gsub("'", "''", unlist(strsplit(elements[2], ",")), fixed = TRUE)
      paste0(elements[1], " in (", paste0("'", lvls, "'", collapse = ","), ")")
    }
  }

  branches <- vapply(names(rpart.rules), function(node_id) {
    predicates <- unlist(lapply(rpart.rules[[node_id]], to_predicate))
    if (length(predicates) == 0L) {
      # A tree with no splits has only the root: match every row.
      predicates <- "1=1"
    }
    # Read the prediction from the model passed in, not from a global object.
    yval <- frame$yval[match(node_id, rownames(frame))]
    paste0("when ", paste(predicates, collapse = " AND "), " then ",
           sprintf("%f /*node %s */", yval, node_id))
  }, character(1), USE.NAMES = FALSE)

  paste(c("case ", branches, " end "), collapse = " ")
}





# Convert the fitted rpart model to a SQL CASE expression and print it.
results <- parse.tree.to.sql(df = df, model = rpart.m)
# If the output looks truncated, raise the console line limit in RStudio.
cat(results)

Among other things, this produces:

case  when y <  -1.449 then 0.012000 /*node 2 */ when y >= -1.449 AND y >= 3.327 then 0.050000 /*node 6 */ when y >= -1.449 AND y <  3.327 AND x >= 3.449 then 0.025641 /*node 14 */ when y >= -1.449 AND y <  3.327 AND x <  3.449 AND x <  -1.449 then 0.030769 /*node 30 */ when y >= -1.449 AND y <  3.327 AND x <  3.449 AND x >= -1.449 AND y <  2.469 AND x <  2.469 AND y >= -0.3469 AND x >= -0.3469 then 0.079395 /*node 496 */ when y >= -1.449 AND y <  3.327 AND x <  3.449 AND x >= -1.449 AND y <  2.469 AND x <  2.469 AND y >= -0.3469 AND x <  -0.3469 then 0.753623 /*node 497 */ when y >= -1.449 AND y <  3.327 AND x <  3.449 AND x >= -1.449 AND y <  2.469 AND x <  2.469 AND y <  -0.3469 AND x <  -0.5918 then 0.174603 /*node 498 */ when y >= -1.449 AND y <  3.327 AND x <  3.449 AND x >= -1.449 AND y <  2.469 AND x <  2.469 AND y <  -0.3469 AND x >= -0.5918 then 0.746667 /*node 499 */ when y >= -1.449 AND y <  3.327 AND x <  3.449 AND x >= -1.449 AND y <  2.469 AND x >= 2.469 AND y <  -0.5918 then 0.142857 /*node 250 */ when y >= -1.449 AND y <  3.327 AND x <  3.449 AND x >= -1.449 AND y <  2.469 AND x >= 2.469 AND y >= -0.5918 then 0.775000 /*node 251 */ when y >= -1.449 AND y <  3.327 AND x <  3.449 AND x >= -1.449 AND y >= 2.469 AND x >= 2.714 then 0.095238 /*node 126 */ when y >= -1.449 AND y <  3.327 AND x <  3.449 AND x >= -1.449 AND y >= 2.469 AND x <  2.714 AND x <  -0.5918 then 0.163265 /*node 254 */ when y >= -1.449 AND y <  3.327 AND x <  3.449 AND x >= -1.449 AND y >= 2.469 AND x <  2.714 AND x >= -0.5918 then 0.814815 /*node 255 */  end 
Chris
  • 1,219
  • 2
  • 11
  • 21
  • 2
    It would be helpful if you provided a [reproducible example](https://stackoverflow.com/questions/5963269/how-to-make-a-great-r-reproducible-example) for testing with sample input data. We don't have `df` so we can't run this currently. – MrFlick May 30 '17 at 21:43
  • I posted code. I think I may have solved it, however, I suspect the case statement isn't as concise as a nested one. – Chris May 31 '17 at 00:46
  • I read that SQL Server allows only 10 levels of nesting in CASE expressions. I was able to write a nested-CASE function using the undocumented rpart:::labels.rpart function, but the non-nested version may be best because of that 10-level restriction. – Chris Jun 01 '17 at 15:58
  • Replacing both "rpart.m" with "model" inside the parse.tree.to.sql function makes it self contained. rpart.m refers to the model outside of the function, which is unnecessary and hinders the function's usefulness. – ARobertson Aug 09 '23 at 23:39

0 Answers0