I have been working through the book Introduction to Data Science by Rafael A. Irizarry, and I keep coming across decision boundary plots that I would like to recreate (the left one in the image below).

I found code to create decision boundary plots at https://mhahsler.github.io/Introduction_to_Data_Mining_R_Examples/book/classification-alternative-techniques.html#decision-boundaries , which does the job, but the plots don't look like the one in the book.

library(randomForest)
library(tidyverse)
library(caret)
library(dslabs)

decisionplot <- function(model, data, class = NULL, predict_type = "class",
                         resolution = 100, showgrid = TRUE, ...) {
  
  if(!is.null(class)) cl <- data[,class] else cl <- 1
  data <- data[,1:2]
  k <- length(unique(cl))
  
  plot(data, col = as.integer(cl)+1L, pch = as.integer(cl)+1L, ...)
  
  # make grid
  r <- sapply(data, range, na.rm = TRUE)
  xs <- seq(r[1,1], r[2,1], length.out = resolution)
  ys <- seq(r[1,2], r[2,2], length.out = resolution)
  g <- cbind(rep(xs, each = resolution), rep(ys, times = resolution))
  colnames(g) <- colnames(r)
  g <- as.data.frame(g)
  
  ### guess how to get class labels from predict
  ### (unfortunately not very consistent between models)
  p <- predict(model, g, type = predict_type)
  if(is.list(p)) p <- p$class
  p <- as.factor(p)
  
  if(showgrid) points(g, col = as.integer(p)+1L, pch = ".")
  
  z <- matrix(as.integer(p), nrow = resolution, byrow = TRUE)
  contour(xs, ys, z, add = TRUE, drawlabels = FALSE,
          lwd = 2, levels = (1:(k-1))+.5)
  
  invisible(z)
}


train_rf <- randomForest(y ~ ., data = mnist_27$train)

decisionplot(train_rf, data = mnist_27$train %>% select(x_1, x_2, y), class = "y")

I need help making the decision boundary plots look like the ones in the book.

Michael Hahsler
Harkaran Saini
  • Have you tried actually trying to recreate the plots? – NelsonGon May 01 '19 at 12:35
  • @NelsonGon I did have a look at the decisionplot function but couldn't figure out how to use ggplot in it. I'm still a beginner and comfortable using ggplot right out of the box but haven't learnt how to customize the plots. – Harkaran Saini May 01 '19 at 12:48
  • Take a look at [this](https://stackoverflow.com/questions/39822505/drawing-decision-boundaries-in-r). It might help. – NelsonGon May 01 '19 at 12:49
  • @NelsonGon I am trying to get the boundaries on the full multi-class problem (all the digits from 0 to 9). Do you have any suggestions for how I modify this code - below OP only compares predicted probs for digits 2 and 7, but I have all digits 0 to 9. – user2450223 May 19 '22 at 20:48
  • Sorry I am not on a computer right now and less active on Stackoverflow. I'll take a look sometime soon. – NelsonGon May 21 '22 at 18:24
  • @NelsonGon I have posted here: https://stackoverflow.com/questions/72339619/plotting-decision-boundary-for-a-multiclass-random-forest-model Could you take a look whenever you can? Thank you. – user2450223 May 23 '22 at 02:52

1 Answer

Thanks, Nelson. I looked at your link and a few other resources and arrived at this.

library(randomForest)
library(tidyverse)
library(dslabs)
library(ggthemes)  # provides theme_few()

model <- randomForest(y ~ ., data = mnist_27$train)
data <- mnist_27$train %>% select(x_1, x_2)
resolution <- 75

# Build a regular grid covering the range of both predictors
r <- sapply(data, range, na.rm = TRUE)
xs <- seq(r[1, 1], r[2, 1], length.out = resolution)
ys <- seq(r[1, 2], r[2, 2], length.out = resolution)
g <- cbind(rep(xs, each = resolution), rep(ys, times = resolution))
colnames(g) <- colnames(r)
g <- as.data.frame(g)

# Predicted class and class probabilities at every grid point
q <- predict(model, g, type = "class")
p <- predict(model, g, type = "prob") %>% as.data.frame()
g <- g %>% mutate(prob_2 = p$`2`, pred = as.integer(q))

# Fill by P(y = 2); the contour at 1.5 separates the integer class codes 1 and 2
ggplot(g, aes(x = x_1, y = x_2)) +
  geom_raster(aes(fill = prob_2), interpolate = TRUE) +
  # linewidth needs ggplot2 >= 3.4.0; use size = 1 on older versions
  geom_contour(aes(z = pred), breaks = 1.5, color = "black", linewidth = 1) +
  theme_few() +
  labs(fill = "") +
  scale_fill_gradient2(low = "#338cea", mid = "white", high = "#dd7e7e",
                       midpoint = 0.5, limits = range(g$prob_2)) +
  theme(legend.position = "none")
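
A possible variant, offered only as a sketch: instead of contouring the integer class codes at 1.5, the boundary can be drawn where the predicted probability of one class crosses 0.5. For a two-class forest this should land in roughly the same place. The helper name `plot_prob_boundary` and its arguments are hypothetical (they come from neither the book nor the linked page), and it assumes the first two columns of `data` are the predictors and that the model's `predict()` supports `type = "prob"`.

```r
library(randomForest)
library(ggplot2)
library(dslabs)

# Hypothetical helper: probability surface plus the 0.5 contour.
# Assumes columns 1 and 2 of `data` are the two predictors.
plot_prob_boundary <- function(model, data, positive = "2", resolution = 75) {
  vars <- names(data)[1:2]

  # Regular grid over the observed range of both predictors
  grid <- expand.grid(
    seq(min(data[[1]]), max(data[[1]]), length.out = resolution),
    seq(min(data[[2]]), max(data[[2]]), length.out = resolution)
  )
  names(grid) <- vars

  # Probability of the `positive` class at every grid point
  grid$prob <- predict(model, grid, type = "prob")[, positive]

  ggplot(grid, aes(x = .data[[vars[1]]], y = .data[[vars[2]]])) +
    geom_raster(aes(fill = prob), interpolate = TRUE) +
    geom_contour(aes(z = prob), breaks = 0.5, colour = "black") +
    scale_fill_gradient2(low = "#338cea", mid = "white", high = "#dd7e7e",
                         midpoint = 0.5, limits = c(0, 1)) +
    labs(x = vars[1], y = vars[2], fill = paste0("P(y = ", positive, ")"))
}

set.seed(1)
fit <- randomForest(y ~ ., data = mnist_27$train)
plt <- plot_prob_boundary(fit, mnist_27$train[, c("x_1", "x_2", "y")])
print(plt)
```

Keeping the grid and the predictions in one data frame avoids mixing `g$...` and `p$...` inside `aes()`, which makes the plot easier to restyle later.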






Harkaran Saini
  • Great question. I am trying to do the same with the full MNIST data with all the digits, in an RF model. How do I get the boundaries between all the class labels (the digits 0 to 9)? You have only compared the predictions for digits 2 and 7. Thank you! – user2450223 May 19 '22 at 20:46
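
For the multi-class case raised in the comments, one approach that might work is to predict the class on the grid and then draw one contour per class around a 0/1 indicator of "predicted as this class", with the break at 0.5. This is a rough sketch under stated assumptions, not a tested recipe for MNIST: it uses the built-in `iris` data (three classes, two petal measurements) as a stand-in, because a 2-D boundary plot needs exactly two predictors. The object names (`boundary_layers`, `grid`, `plt`) are illustrative only.

```r
library(randomForest)
library(ggplot2)

set.seed(1)

# Stand-in data (assumption): three classes, two predictors
train <- iris[, c("Petal.Length", "Petal.Width", "Species")]
fit <- randomForest(Species ~ ., data = train)

# Regular grid over the observed range of both predictors
resolution <- 100
grid <- expand.grid(
  Petal.Length = seq(min(train$Petal.Length), max(train$Petal.Length),
                     length.out = resolution),
  Petal.Width  = seq(min(train$Petal.Width), max(train$Petal.Width),
                     length.out = resolution)
)
grid$pred <- predict(fit, grid, type = "response")

# One contour layer per class: z is 1 where the class is predicted, else 0,
# so the 0.5 level traces the edge of that class's region
boundary_layers <- lapply(levels(grid$pred), function(k) {
  geom_contour(
    data = transform(grid, z = as.numeric(pred == k)),
    aes(x = Petal.Length, y = Petal.Width, z = z),
    breaks = 0.5, colour = "black"
  )
})

plt <- ggplot() +
  geom_raster(data = grid,
              aes(x = Petal.Length, y = Petal.Width, fill = pred),
              alpha = 0.4) +
  boundary_layers +
  geom_point(data = train,
             aes(x = Petal.Length, y = Petal.Width, colour = Species)) +
  labs(fill = "Predicted", colour = "Observed")
print(plt)
```

With ten digit classes the same loop would produce ten indicator contours. The two predictors would still have to be chosen or constructed first, which this sketch does not address.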