
I'm working on a text classification project under the tidymodels framework. Right now, I'm trying to investigate whether particular data points are being consistently mislabeled, so I want to inspect the saved predictions for individual samples. When I perform resampling and use collect_predictions(), I see a table containing the predicted label and the actual label for each data point, but the identity of the data points themselves is still hidden. There's one column that may trace back to them (.row), but I'm having trouble confirming this.

I've been generating my resampling strategy as follows:

grades_split <- initial_split(tabled_texts2, strata = grade)

grades_train <- training(grades_split)
grades_test <- testing(grades_split)

folds <- vfold_cv(grades_train)

Then, after tuning and fitting the model, I generate the resamples object:

fitted_grades <- fit(final_wf, grades_train)

LR_rs <- fit_resamples(
  fitted_grades,
  folds,
  control = control_resamples(save_pred = TRUE)
)

Finally, I examine the predictions like this:

predictions <- collect_predictions(LR_rs)
View(predictions)

I get a table that looks like this:

id      .pred_4   `.pred_not 4`   .row  .pred_class  grade  .config
Fold01  0.502905       0.497095     18  4            4      Preprocessor1_Model1
Fold01  0.484647       0.515353     22  not 4        4      Preprocessor1_Model1
Fold01  0.481496       0.518504     23  not 4        4      Preprocessor1_Model1
Fold01  0.492314       0.507686     40  not 4        4      Preprocessor1_Model1
Fold01  0.477215       0.522785     52  not 4        4      Preprocessor1_Model1

How could I map these values back to the original data?

Here is an analogous reprex. In this example, I would like to see specifically which penguins are being misclassified, not just an arbitrary .row value (which I'm pretty sure doesn't map 1-to-1 back to the original dataset).

library(tidyverse)
library(tidymodels)
library(tidytext)
library(modeldata)
library(naivebayes)
library(discrim)
set.seed(1)

data("penguins")
View(penguins)
nb_spec <- naive_Bayes() %>%
  set_mode('classification') %>%
  set_engine('naivebayes')

fitted_wf <- workflow() %>%
  add_formula(species ~ island + flipper_length_mm) %>%
  add_model(nb_spec) %>%
  fit(penguins)


split <- initial_split(penguins)

train <- training(split)
test <- testing(split)

folds <- vfold_cv(train)

NB_rs <- fit_resamples(
  fitted_wf,
  folds,
  control = control_resamples(save_pred = TRUE)
)
predictions <- collect_predictions(NB_rs)
View(predictions)
    Are you trying to get the predictions for each of your crossfold validations? Or have you somehow chosen a "final" model you want to use to make predictions on? It's easier to help you if you include a simple [reproducible example](https://stackoverflow.com/questions/5963269/how-to-make-a-great-r-reproducible-example) with sample input and desired output that can be used to test and verify possible solutions. – MrFlick Jul 22 '21 at 18:08

1 Answer


The .row column does in fact tell you which row each of these predictions is, from the training dataset. Let's see if we can convince you of this:

library(tidyverse)
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip
library(discrim)
#> 
#> Attaching package: 'discrim'
#> The following object is masked from 'package:dials':
#> 
#>     smoothness
set.seed(1)

data("penguins")
nb_spec <- naive_Bayes() %>%
   set_mode('classification') %>%
   set_engine('naivebayes')

fitted_wf <- workflow() %>%
   add_formula(species ~ island + flipper_length_mm) %>%
   add_model(nb_spec) 

split <- penguins %>%
   na.omit() %>%
   initial_split()

penguin_train <- training(split)
penguin_test <- testing(split)

folds <- vfold_cv(penguin_train)

NB_rs <- fit_resamples(
   fitted_wf,
   folds,
   control = control_resamples(save_pred = TRUE)
)

predictions <- collect_predictions(NB_rs)

Let's look at just one of the folds:

predictions %>% filter(id == "Fold01")
#> # A tibble: 25 × 8
#>    id     .pred_Adelie .pred_Chinstrap .pred_Gentoo  .row .pred_class species  
#>    <chr>         <dbl>           <dbl>        <dbl> <int> <fct>       <fct>    
#>  1 Fold01     0.609        0.391        0.000000526     3 Adelie      Adelie   
#>  2 Fold01     0.182        0.818        0.000104        8 Chinstrap   Adelie   
#>  3 Fold01     0.423        0.577        0.000000325     9 Chinstrap   Chinstrap
#>  4 Fold01     0.999        0.00120      0.00000137     21 Adelie      Adelie   
#>  5 Fold01     0.000178     0.0000310    1.00           27 Gentoo      Gentoo   
#>  6 Fold01     0.552        0.448        0.000000395    36 Adelie      Adelie   
#>  7 Fold01     0.997        0.000392     0.00275        45 Adelie      Adelie   
#>  8 Fold01     0.000211     0.000000780  1.00           48 Gentoo      Gentoo   
#>  9 Fold01     0.998        0.00129      0.00114        60 Adelie      Adelie   
#> 10 Fold01     0.00313      0.000100     0.997          79 Gentoo      Gentoo   
#> # … with 15 more rows, and 1 more variable: .config <chr>

This has rows 3, 8, 9, etc. It is the assessment set of the first resample in folds.
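If you'd like to confirm that programmatically rather than by eye, here is a quick sanity check (a sketch, assuming the folds and predictions objects from the reprex above) using rsample's complement(), which returns the indices of a split's assessment set:

```r
library(rsample)
library(dplyr)

# Indices of the rows held out in the first resample (the assessment set):
held_out <- complement(folds$splits[[1]])

# These should be exactly the .row values reported for Fold01:
identical(
  sort(held_out),
  sort(predictions$.row[predictions$id == "Fold01"])
)
```

If this returns TRUE, then .row is indexing into penguin_train, one row per held-out observation in that fold.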

Now let's look at the training data:

penguin_train
#> # A tibble: 249 × 7
#>    species   island bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
#>    <fct>     <fct>           <dbl>         <dbl>             <int>       <int>
#>  1 Chinstrap Dream            50.2          18.8               202        3800
#>  2 Gentoo    Biscoe           50.2          14.3               218        5700
#>  3 Adelie    Dream            38.1          17.6               187        3425
#>  4 Chinstrap Dream            51            18.8               203        4100
#>  5 Chinstrap Dream            52.7          19.8               197        3725
#>  6 Gentoo    Biscoe           49.6          16                 225        5700
#>  7 Chinstrap Dream            46.2          17.5               187        3650
#>  8 Adelie    Dream            35.7          18                 202        3550
#>  9 Chinstrap Dream            51.7          20.3               194        3775
#> 10 Gentoo    Biscoe           50.4          15.7               222        5750
#> # … with 239 more rows, and 1 more variable: sex <fct>

Created on 2021-07-30 by the reprex package (v2.0.0)

Look at rows 3, 8, and 9; the species match up because these are the same rows!
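To do that mapping programmatically, you can attach a row index to the training data and join the predictions back on. This is a sketch (the penguin_misses name is just illustrative); joining on both .row and species avoids a duplicated species column:

```r
library(dplyr)

# Attach a .row index to the training data, then join on the resampled
# predictions, keeping only the misclassified observations.
penguin_misses <- penguin_train %>%
  mutate(.row = row_number()) %>%
  inner_join(
    predictions %>% filter(.pred_class != species),
    by = c(".row", "species")
  )

# Each row is now an original penguin plus the (wrong) prediction it got:
penguin_misses %>%
  select(.row, species, .pred_class, island, flipper_length_mm)
```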

Do be aware that you may get different predictions across the folds in folds, because each resample is fit on a different subset of the training data (what we call the analysis set).
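Since the original goal was spotting consistently mislabeled observations, one approach (a sketch, assuming repeated cross-validation is appropriate for your problem) is to use vfold_cv() with repeats so that every training row is predicted several times, then tally misclassifications per .row:

```r
library(tidymodels)

# Repeated CV: each row lands in an assessment set once per repeat.
folds_rep <- vfold_cv(penguin_train, v = 10, repeats = 5)

NB_rs_rep <- fit_resamples(
  fitted_wf,
  folds_rep,
  control = control_resamples(save_pred = TRUE)
)

# Count how often each training row was misclassified across repeats:
collect_predictions(NB_rs_rep) %>%
  group_by(.row) %>%
  summarize(n_preds = n(),
            n_wrong = sum(.pred_class != species)) %>%
  arrange(desc(n_wrong))
```

Rows with n_wrong close to n_preds are the ones your models get wrong regardless of which analysis set they were fit on.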

Julia Silge
  • This was perfect, thank you so much! My mistake was that I was trying to match the .row information to the original dataset, when I should have been looking at the training set. – Paul Braymen Aug 01 '21 at 13:00