
I have just learnt about the KNN algorithm and machine learning. It is a lot for me to take in, and we are using tidymodels in R to practice.

Now, I know how to implement a grid search using k-fold cross-validation as follows:

hist_data_split <- initial_split(hist_data, strata = fraud)
hist_data_train <- training(hist_data_split)
hist_data_test <- testing(hist_data_split)
folds <- vfold_cv(hist_data_train, strata = fraud)
nearest_neighbor_grid <- grid_regular(neighbors(range = c(1, 500)), levels = 25)
knn_rec_1 <- recipe(fraud ~ ., data = hist_data_train)
knn_spec_1 <- nearest_neighbor(mode = "classification", engine = "kknn", neighbors = tune(), weight_func = "rectangular")
knn_wf_1 <- workflow(preprocessor = knn_rec_1, spec = knn_spec_1)
knn_fit_1 <- tune_grid(
  knn_wf_1,
  resamples = folds,
  metrics = metric_set(accuracy, sens, spec, roc_auc),
  control = control_resamples(save_pred = TRUE),
  grid = nearest_neighbor_grid
)

In the above case, I am essentially running a 10-fold cross-validated grid search to tune my model. However, hist_data has 169,173 rows, which gives an optimal K of about 411, and with 10-fold cross-validation the tuning is going to take forever. The hint given is therefore to use a single validation fold instead of cross-validation.
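(For reference, that value matches the usual K ≈ √n rule of thumb:)

sqrt(169173)  # ≈ 411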

Thus, I am wondering how I can tweak my code to implement this. When I add the argument v = 1 in vfold_cv(), R throws an error that says, "At least one row should be selected for the analysis set." Should I instead change resamples = folds in tune_grid() to resamples = 1?

Any intuitive suggestions will be greatly appreciated :)

P.S. I did not include an MWE (the data is not provided) because I feel this is a fairly trivial question that can be answered as is!

Ethan Mark

1 Answer


If you are not able to do a cross-validation split, for whatever reason, you can do a validation split, which conceptually is very close to a v = 1 cross-validation.

library(tidymodels)

hist_data_split <- initial_split(ames, strata = Street)
hist_data_train <- training(hist_data_split)
hist_data_test <- testing(hist_data_split)

# a single validation split (one analysis/assessment pair) instead of v-fold CV
folds <- validation_split(hist_data_train, strata = Street)

nearest_neighbor_grid <- grid_regular(
  neighbors(range = c(1, 500)), 
  levels = 25
)

knn_rec_1 <- recipe(Street ~ ., data = hist_data_train)
knn_spec_1 <- nearest_neighbor(neighbors = tune()) %>%
  set_mode("classification") %>%
  set_engine("kknn") %>%
  set_args(weight_func = "rectangular")

knn_wf_1 <- workflow(preprocessor = knn_rec_1, spec = knn_spec_1)

knn_fit_1 <- tune_grid(
  knn_wf_1,
  resamples = folds,
  metrics = metric_set(accuracy, sens, spec, roc_auc),
  control = control_resamples(save_pred = TRUE),
  grid = nearest_neighbor_grid
)

knn_fit_1
#> # Tuning results
#> # Validation Set Split (0.75/0.25)  using stratification 
#> # A tibble: 1 × 5
#>   splits             id         .metrics           .notes           .predictions
#>   <list>             <chr>      <list>             <list>           <list>      
#> 1 <split [1647/550]> validation <tibble [100 × 5]> <tibble [0 × 3]> <tibble>

knn_fit_1 %>%
  collect_metrics()
#> # A tibble: 100 × 7
#>    neighbors .metric  .estimator  mean     n std_err .config              
#>        <int> <chr>    <chr>      <dbl> <int>   <dbl> <chr>                
#>  1         1 accuracy binary     0.996     1      NA Preprocessor1_Model01
#>  2         1 roc_auc  binary     0.5       1      NA Preprocessor1_Model01
#>  3         1 sens     binary     0         1      NA Preprocessor1_Model01
#>  4         1 spec     binary     1         1      NA Preprocessor1_Model01
#>  5        21 accuracy binary     0.996     1      NA Preprocessor1_Model02
#>  6        21 roc_auc  binary     0.495     1      NA Preprocessor1_Model02
#>  7        21 sens     binary     0         1      NA Preprocessor1_Model02
#>  8        21 spec     binary     1         1      NA Preprocessor1_Model02
#>  9        42 accuracy binary     0.996     1      NA Preprocessor1_Model03
#> 10        42 roc_auc  binary     0.486     1      NA Preprocessor1_Model03
#> # … with 90 more rows

Created on 2022-09-06 by the reprex package (v2.0.1)
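From here, a typical tidymodels follow-up is to pick the best neighbors value from the validation results, finalize the workflow, and fit once on the full training set while evaluating on the held-out test split. A minimal sketch, assuming the objects created above (the names best_k, final_knn_wf, and final_knn_res are just illustrative):

# neighbors value that did best on the validation split (by ROC AUC here)
best_k <- select_best(knn_fit_1, metric = "roc_auc")

# plug that value into the tunable workflow
final_knn_wf <- finalize_workflow(knn_wf_1, best_k)

# one last fit on the training data, evaluated once on the test set
final_knn_res <- last_fit(final_knn_wf, hist_data_split)
collect_metrics(final_knn_res)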

EmilHvitfeldt