[EDIT: I understand that it is faster partly because the function is written in C, but I want to know whether it does a brute-force search over all the training instances or something more sophisticated.]
For study purposes, I'm implementing the KNN algorithm in R, and I'm checking its correctness by comparing it with the caret implementation.
The problem is the execution time of the two versions. My version takes a long time, whereas the caret implementation is very fast (even with 10-fold cross-validation).
Why? I'm computing the Euclidean distance from every test instance to every training instance, which means N×M distance calculations (where N is the number of test instances and M the number of training instances):
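for (i in 1:nrow(test)) {
  # calculateDistance() is my own Euclidean distance function
  distances <- numeric(nrow(training))   # preallocated instead of grown with c()
  classes <- training[, 15]              # class labels of all training rows (column 15)
  for (j in 1:nrow(training)) {
    distances[j] <- calculateDistance(test[i, ], training[j, ])
  }
  # ... pick the k nearest neighbours from distances/classes (not shown)
}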
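For comparison, here is the complete brute-force procedure in plain C++. This is a minimal sketch under my own assumptions, not caret's code; the names `Row`, `squaredDistance` and `knnClassify` are hypothetical, and each `Row` is assumed to hold only the features (no class column). It still performs all N×M distance computations, but it compares squared distances (the square root does not change the ordering) and uses `std::partial_sort` so that only the k nearest indices get ordered.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <map>
#include <numeric>
#include <vector>

// Hypothetical row type for this sketch: features only, no class label.
using Row = std::vector<double>;

// Squared Euclidean distance; skipping sqrt does not change the ranking.
double squaredDistance(const Row& a, const Row& b) {
    double sum = 0.0;
    for (std::size_t f = 0; f < a.size(); ++f) {
        double diff = a[f] - b[f];
        sum += diff * diff;
    }
    return sum;
}

// Exact brute-force kNN: compare the query with all M training rows,
// keep the k closest and return the majority label.
int knnClassify(const std::vector<Row>& training, const std::vector<int>& labels,
                const Row& query, std::size_t k) {
    assert(k > 0 && k <= training.size() && training.size() == labels.size());
    std::vector<double> dist(training.size());
    for (std::size_t j = 0; j < training.size(); ++j)
        dist[j] = squaredDistance(query, training[j]);

    // Order only the first k indices by distance: O(M log k) rather than O(M log M).
    std::vector<std::size_t> idx(training.size());
    std::iota(idx.begin(), idx.end(), std::size_t{0});
    std::partial_sort(idx.begin(), idx.begin() + static_cast<std::ptrdiff_t>(k), idx.end(),
                      [&](std::size_t a, std::size_t b) { return dist[a] < dist[b]; });

    // Majority vote; on a tie the smallest label wins (std::map iterates in order).
    std::map<int, int> votes;
    for (std::size_t n = 0; n < k; ++n) ++votes[labels[idx[n]]];
    int best = votes.begin()->first;
    int bestCount = 0;
    for (const auto& kv : votes)
        if (kv.second > bestCount) { best = kv.first; bestCount = kv.second; }
    return best;
}
```

Calling `knnClassify` once per test row is still an exact, exhaustive search costing O(N·M·d) for d features, so a speed-up of this kind comes from constant factors (no vector growth, no interpreter overhead), not from visiting fewer training rows.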
Does the caret implementation use some approximate search, or an exact search, for example with a kd-tree? How can I speed up the search? My problem has 14 features, but I've read that kd-trees are recommended for problems with 1 to 5 features.
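To make the kd-tree option concrete, here is a minimal exact 1-nearest-neighbour kd-tree in C++. It is only an illustration I wrote under my own assumptions (the names `Point`, `KdNode`, `KdTree` and `nearest` are hypothetical) and says nothing about what caret does internally. The tree splits on one coordinate per level at the median; the search descends to the query's side first and visits the other side only if the splitting plane is closer than the best match found so far. That pruning test is what saves work, and it rejects fewer subtrees as the number of dimensions grows, which is the usual reason kd-trees are recommended for low-dimensional data.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <limits>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical point type for this sketch: features only.
using Point = std::vector<double>;

struct KdNode {
    std::size_t index = 0;            // row index into the point set
    std::size_t axis = 0;             // coordinate used to split at this node
    std::unique_ptr<KdNode> left, right;
};

class KdTree {
public:
    explicit KdTree(std::vector<Point> pts) : points(std::move(pts)) {
        assert(!points.empty());
        std::vector<std::size_t> idx(points.size());
        for (std::size_t i = 0; i < idx.size(); ++i) idx[i] = i;
        root = build(idx, 0, idx.size(), 0);
    }

    // Exact 1-NN: index of a point at minimal Euclidean distance from q.
    std::size_t nearest(const Point& q) const {
        std::size_t best = 0;
        double bestDist = std::numeric_limits<double>::infinity();
        search(root.get(), q, best, bestDist);
        return best;
    }

private:
    std::vector<Point> points;
    std::unique_ptr<KdNode> root;

    std::unique_ptr<KdNode> build(std::vector<std::size_t>& idx, std::size_t lo,
                                  std::size_t hi, std::size_t depth) {
        if (lo >= hi) return nullptr;
        std::size_t axis = depth % points[0].size();   // cycle through the coordinates
        std::size_t mid = lo + (hi - lo) / 2;
        // Move the median along `axis` to position mid (smaller left, larger right).
        std::nth_element(idx.begin() + lo, idx.begin() + mid, idx.begin() + hi,
                         [&](std::size_t a, std::size_t b) { return points[a][axis] < points[b][axis]; });
        auto node = std::make_unique<KdNode>();
        node->index = idx[mid];
        node->axis = axis;
        node->left = build(idx, lo, mid, depth + 1);
        node->right = build(idx, mid + 1, hi, depth + 1);
        return node;
    }

    void search(const KdNode* node, const Point& q, std::size_t& best, double& bestDist) const {
        if (!node) return;
        const Point& p = points[node->index];
        double d = 0.0;                                 // squared distance to this node's point
        for (std::size_t f = 0; f < q.size(); ++f) d += (q[f] - p[f]) * (q[f] - p[f]);
        if (d < bestDist) { bestDist = d; best = node->index; }

        double diff = q[node->axis] - p[node->axis];
        const KdNode* nearSide = diff < 0 ? node->left.get() : node->right.get();
        const KdNode* farSide = diff < 0 ? node->right.get() : node->left.get();
        search(nearSide, q, best, bestDist);
        // Prune: the far side can only help if the splitting plane is closer than the best match.
        if (diff * diff < bestDist) search(farSide, q, best, bestDist);
    }
};
```

Extending this to k neighbours means keeping the k best candidates instead of one and pruning against the k-th best distance.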
EDIT:
I've found the C function called from R (VR_knn), but it is quite hard for me to follow; maybe someone can help.
Anyway, I've quickly written a brute-force search in C++ (via Rcpp), which seems faster than my previous R version, but not as fast as the caret C version:
#include <Rcpp.h>
#include <cmath>
using namespace Rcpp;

// Euclidean distance between row i of a and row j of b, using only the
// first nFeatures columns (the last column holds the class label).
double distance(const NumericMatrix& a, int i, const NumericMatrix& b, int j, int nFeatures){
    double sum = 0.0;
    for (int k = 0; k < nFeatures; k++){
        double diff = a(i, k) - b(j, k);
        sum += diff * diff;   // cheaper than pow(diff, 2)
    }
    return std::sqrt(sum);
}

// [[Rcpp::export]]
NumericMatrix searchCpp(NumericMatrix training, NumericMatrix test) {
    int numRowTr = training.rows();
    int numRowTe = test.rows();
    int nFeatures = training.cols() - 1;   // skip the class column
    // distances(i, j) = distance from test row i to training row j
    NumericMatrix distances(numRowTe, numRowTr);
    for (int i = 0; i < numRowTe; i++){
        for (int j = 0; j < numRowTr; j++){
            // write in place; insert() would reallocate and grow the vector on every call
            distances(i, j) = distance(test, i, training, j, nFeatures);
        }
    }
    return distances;
}
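Since only the k smallest distances matter, a brute-force search does not need to store or sort all M distances for each test row. One common approach, sketched below under my own assumptions (this is not a transcription of VR_knn, whose internals I haven't verified, and `kNearest` is a hypothetical name), keeps a small sorted buffer of the k best candidates and touches it only when a new distance beats the current k-th best.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <utility>
#include <vector>

// Hypothetical helper: indices of the k training rows closest to `query`,
// ordered by increasing squared Euclidean distance. Rows hold features only.
std::vector<std::size_t> kNearest(const std::vector<std::vector<double>>& training,
                                  const std::vector<double>& query, std::size_t k) {
    assert(k > 0 && k <= training.size());
    // Sorted buffer of (squared distance, row index), initially "infinitely far".
    std::vector<std::pair<double, std::size_t>> best(
        k, {std::numeric_limits<double>::infinity(), 0});

    for (std::size_t j = 0; j < training.size(); ++j) {
        double d = 0.0;
        for (std::size_t f = 0; f < query.size(); ++f) {
            double diff = query[f] - training[j][f];
            d += diff * diff;
        }
        if (d >= best[k - 1].first) continue;   // no better than the current k-th best
        // Insertion step: shift larger entries one slot right, then place the new one.
        std::size_t pos = k - 1;
        while (pos > 0 && best[pos - 1].first > d) {
            best[pos] = best[pos - 1];
            --pos;
        }
        best[pos] = {d, j};
    }

    std::vector<std::size_t> result;
    result.reserve(k);
    for (const auto& entry : best) result.push_back(entry.second);
    return result;
}
```

The search is still exhaustive and exact, but per test row it needs O(k) extra memory instead of O(M), and for small k the insertion step is cheap because most training rows are rejected by the single comparison against the k-th best distance.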