I am using an SVM for multi-class classification with the e1071 package in R.

I want cross-validated probability predictions for each class and each data point in the training set, i.e. an N x K matrix of cross-validated probabilities (N data points, K classes).

Can someone tell me how to do it?

mathkid

1 Answer

A few points:

(1) With cross-validation you measure the accuracy of your model (trained on the training dataset) on the held-out dataset, not on the entire dataset.

(2) You need to select the values of the hyper-parameters (C, gamma) before you compute the matrix.

(3) You can use the caret package to compute the desired probability matrix, but since it's a multiclass classification problem, you need to choose which class to compute the probability for before you compute the matrix.
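
Points (1) and (2) can also be sketched directly with e1071, the package named in the question. This is only a minimal sketch under stated assumptions: the `cost` and `gamma` values are placeholders rather than tuned choices, and the names `K_folds`, `fold_id` and `cv_prob` are made up for illustration. Each data point gets its probabilities from a model that never saw it during fitting, which yields one row per data point and one column per class.

```r
library(e1071)  # svm() and its predict() method

set.seed(123)
K_folds <- 10                        # number of CV folds (assumed value)
n <- nrow(iris)
classes <- levels(iris$Species)

# Randomly assign every row to one of K_folds folds of equal size
fold_id <- sample(rep(seq_len(K_folds), length.out = n))

# Matrix of out-of-fold probabilities: one row per point, one column per class
cv_prob <- matrix(NA_real_, nrow = n, ncol = length(classes),
                  dimnames = list(NULL, classes))

for (f in seq_len(K_folds)) {
  held_out <- which(fold_id == f)
  # Fit on the other folds only; cost and gamma are placeholders, not tuned
  fit <- svm(Species ~ ., data = iris[-held_out, ],
             kernel = "radial", cost = 1, gamma = 0.25, probability = TRUE)
  pred <- predict(fit, newdata = iris[held_out, ], probability = TRUE)
  # Select columns by name, since their order may differ from levels()
  cv_prob[held_out, ] <- attr(pred, "probabilities")[, classes]
}

dim(cv_prob)
head(round(cv_prob, 3))
```

In practice the hyper-parameters would be chosen first, as point (2) says, and the fold loop would then be run with those values.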

With caret, use the following code on iris, which has 150 data points; in each of the 10 folds, 15 points are randomly held out as validation data. Let's take the probability that the predicted class is setosa and build the 150x11 matrix: one column per fold, plus a final binary column indicating whether the actual class of the data point is setosa (rowIndex identifies the data point).

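K <- 10 # number of folds
set.seed(123)
library(caret)     # train(), trainControl(); svmRadial also needs kernlab installed
library(reshape2)  # dcast()
# K-fold CV; keep the held-out predictions and the per-class probabilities
trctl <- trainControl(method = "cv", number = K, savePredictions = TRUE, classProbs = TRUE)
res <- train(Species ~ ., data = iris, method = "svmRadial", trControl = trctl)
# res$pred holds held-out predictions for every candidate C; keep those for C = 1
res.C1 <- subset(res$pred, C == 1)
head(res.C1)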

       pred        obs        setosa  versicolor   virginica rowIndex    sigma C Resample
31     setosa     setosa 0.980011940 0.009115859 0.010872201       17 1.421405 1   Fold01
32     setosa     setosa 0.872285443 0.051664831 0.076049726       23 1.421405 1   Fold01
33     setosa     setosa 0.983836684 0.007452339 0.008710978       35 1.421405 1   Fold01
34     setosa     setosa 0.956874365 0.018767699 0.024357936       38 1.421405 1   Fold01
35     setosa     setosa 0.979355342 0.009425609 0.011219049       39 1.421405 1   Fold01
36 versicolor versicolor 0.009445829 0.935110658 0.055443514       55 1.421405 1   Fold01

# dcast() sorts by rowIndex, so take the labels from iris (row order), not from res.C1$obs (fold order)
cbind.data.frame(round(dcast(res.C1, rowIndex~Resample, value.var = 'setosa'),2), setosa=iris$Species=='setosa')

    rowIndex Fold01 Fold02 Fold03 Fold04 Fold05 Fold06 Fold07 Fold08 Fold09 Fold10 setosa
1          1     NA     NA     NA     NA     NA     NA     NA     NA     NA   0.99   TRUE
2          2     NA     NA     NA     NA     NA     NA     NA     NA   0.98     NA   TRUE
3          3     NA     NA     NA     NA     NA   0.98     NA     NA     NA     NA   TRUE
4          4     NA     NA     NA     NA     NA     NA   0.98     NA     NA     NA   TRUE
5          5     NA     NA     NA   0.99     NA     NA     NA     NA     NA     NA   TRUE
6          6     NA   0.98     NA     NA     NA     NA     NA     NA     NA     NA   TRUE
7          7     NA     NA     NA     NA   0.97     NA     NA     NA     NA     NA   TRUE
8          8     NA     NA   0.99     NA     NA     NA     NA     NA     NA     NA   TRUE
9          9     NA   0.96     NA     NA     NA     NA     NA     NA     NA     NA   TRUE
10        10     NA   0.98     NA     NA     NA     NA     NA     NA     NA     NA   TRUE
#         ...   ...
145      145     NA     NA     NA     NA     NA     NA     NA     NA   0.01     NA  FALSE
146      146     NA     NA     NA   0.01     NA     NA     NA     NA     NA     NA  FALSE
147      147     NA     NA     NA   0.01     NA     NA     NA     NA     NA     NA  FALSE
148      148     NA     NA     NA     NA     NA     NA     NA     NA     NA   0.01  FALSE
149      149     NA     NA     NA     NA     NA     NA     NA     NA   0.02     NA  FALSE
150      150     NA     NA     NA     NA     NA     NA     NA   0.01     NA     NA  FALSE
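
Because `res$pred` already carries one probability column per class, the matrix asked for in the question (one row per data point, one column per class) can also be read off directly. The following is a minimal sketch, assuming the same caret setup as above and the default tuning grid, which includes C = 1; `prob_mat` is a name introduced here for illustration.

```r
library(caret)  # svmRadial additionally requires the kernlab package

set.seed(123)
trctl <- trainControl(method = "cv", number = 10,
                      savePredictions = TRUE, classProbs = TRUE)
res <- train(Species ~ ., data = iris, method = "svmRadial", trControl = trctl)

# Keep the held-out predictions for one candidate cost (C = 1, as above)
res.C1 <- subset(res$pred, C == 1)

# Every row of iris is held out exactly once; restore the original row order
res.C1 <- res.C1[order(res.C1$rowIndex), ]

# One column per class, taken from the class-probability columns of res$pred
prob_mat <- as.matrix(res.C1[, levels(iris$Species)])
rownames(prob_mat) <- res.C1$rowIndex

dim(prob_mat)
head(round(prob_mat, 3))
```

Each row holds the probabilities predicted by the fold model that did not see that data point during training.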
Sandipan Dey
Did you notice any inconsistency between the predicted probabilities and the predicted class labels for SVM? I found some inconsistencies, especially when the predicted probability is close to 0.5. Please check this question: https://stackoverflow.com/questions/63749263/different-results-for-svm-in-r – student_R123 Sep 13 '20 at 19:17