
After designing a tidymodels recipe-based workflow, which is tuned and then fitted to some training data, I'm not clear on which objects (fitted "workflow", "recipe", etc.) should be saved to disk for use in predicting new data in production. I understand I can use saveRDS()/readRDS(), write_rds()/read_rds(), or other options to actually do the saving/loading of these objects, but which objects should I save?

In a clean R environment I will have incoming new raw data that needs to be pre-processed using the same "recipe" I used in training the model, and I then want to make predictions on that pre-processed data. If I intend to use the prep() and bake() functions to pre-process the new data as I did the training data, then it seems I will minimally need the recipe and the original training data to get prep() to work, plus the fitted model/workflow to make predictions. So three objects, it seems. If I save the workflow object to disk in SESSION 1, then in SESSION 2 I can extract the recipe and model from it with pull_workflow_prepped_recipe() and pull_workflow_fit() respectively. But prep() seems to require the original training data, which I can keep in the workflow with an earlier use of retain = TRUE... but that gets stripped out of the workflow after a call to fit(). Hear my cries for help! :)
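To be concrete, this is the kind of extraction I have in mind for SESSION 2 (using the fitted workflow flights_fit built below; note that newer versions of the workflows package call these extract_recipe() and extract_fit_parsnip()):

flights_fit <- read_rds("flights_fit.rds")
prepped_rec <- pull_workflow_prepped_recipe(flights_fit)  # trained recipe
model_fit   <- pull_workflow_fit(flights_fit)             # parsnip model fit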

So, imagine two different R sessions: in the first session I do all the training and model building, and the second session is a running production app that uses what was learned in the first. I need help at the arrows at the bottom of SESSION 1, and in multiple places in SESSION 2. I used the Tidymodels Get Started guide as the basis for this example.

SESSION 1

library(tidymodels)
library(nycflights13)
library(readr)
set.seed(123)
flight_data <- 
  head(flights, 500) %>% 
  mutate(
    arr_delay = ifelse(arr_delay >= 30, "late", "on_time"),
    arr_delay = factor(arr_delay),
    date = as.Date(time_hour)
  ) %>% 
  inner_join(weather, by = c("origin", "time_hour")) %>% 
  select(dep_time, flight, origin, dest, air_time, distance, carrier, date, arr_delay, time_hour) %>% 
  na.omit() %>% 
  mutate_if(is.character, as.factor)

set.seed(555)
data_split <- initial_split(flight_data, prop = 3/4)
train_data <- training(data_split)
test_data  <- testing(data_split)

flights_rec <- 
  recipe(arr_delay ~ ., data = train_data) %>% 
  update_role(flight, time_hour, new_role = "ID") %>% 
  step_date(date, features = c("dow", "month")) %>%
  step_holiday(date, holidays = timeDate::listHolidays("US")) %>%
  step_rm(date) %>%
  step_dummy(all_nominal(), -all_outcomes()) %>%
  step_zv(all_predictors())

lr_mod <- 
  logistic_reg() %>% 
  set_engine("glm")

flights_wflow <- 
  workflow() %>% 
  add_model(lr_mod) %>% 
  add_recipe(flights_rec)

flights_fit <- 
  flights_wflow %>% 
  fit(data = train_data)

predict(flights_fit, test_data)

### SAVE ONE OR MORE OBJECTS HERE FOR NEXT SESSION  <------------
# What to save? workflow (pre or post fit()?), recipe, training data...etc.
write_rds(flights_wflow, "flights_wflow.rds")  # Not fitted workflow
write_rds(flights_fit, "flights_fit.rds")  # Fitted workflow

SESSION 2

### READ ONE OR MORE OBJECTS HERE FROM PRIOR SESSION <------------
flights_wflow <- read_rds("flights_wflow.rds")
flights_fit <- read_rds("flights_fit.rds")

# Acquire new data, do some basic transforms as before
new_flight_data <- 
  tail(flights, 500) %>% 
  mutate(
    arr_delay = ifelse(arr_delay >= 30, "late", "on_time"),
    arr_delay = factor(arr_delay),
    date = as.Date(time_hour)
  ) %>% 
  inner_join(weather, by = c("origin", "time_hour")) %>% 
  select(dep_time, flight, origin, dest, air_time, distance, carrier, date, arr_delay, time_hour) %>% 
  na.omit() %>% 
  mutate_if(is.character, as.factor)

# Something here to preprocess the data with recipe as in SESSION 1 <----------
# new_flight_data_prep <- prep(??)
# new_flight_data_preprocessed <- bake(??)

# Predict new data
predict(flights_fit, new_data = new_flight_data_preprocessed)
— wphampton

2 Answers


You have some flexibility in how you approach this, depending on your constraints, but generally I would recommend saving/serializing the fitted workflow, perhaps after using butcher to reduce its size. You can see an example model fitting script in this repo that shows at the end how I save the fitted workflow.
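For the example in the question, that could look something like this (a sketch; butcher() trims parts of the fitted workflow, such as stored environments, that are not needed for prediction):

library(butcher)

# Shrink the fitted workflow before serializing it
flights_fit_small <- butcher(flights_fit)
saveRDS(flights_fit_small, "flights_fit.rds")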

When you go to predict with this workflow, there are some things to keep in mind. I have an example Plumber API in the same repo that demonstrates what is needed to predict with that particular workflow. Notice how the packages needed for prediction are loaded/attached for this API. I didn't load all of tidymodels, but instead only the specific packages I need, for better performance and a smaller container.
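Applied to SESSION 2 of the question, prediction can then be as simple as the sketch below; the exact packages to attach depend on the recipe steps and model engine involved (here, recipes steps plus a glm fit via parsnip):

# SESSION 2 (sketch): attach only what prediction needs, not all of tidymodels
library(workflows)  # predict() method for fitted workflows
library(recipes)    # applies the trained recipe steps to new data
library(parsnip)    # model interface for the underlying glm fit

flights_fit <- readRDS("flights_fit.rds")

# A fitted workflow carries its trained recipe, so it preprocesses the
# new raw data internally; no separate prep()/bake() calls are needed
predict(flights_fit, new_data = new_flight_data)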

— Julia Silge

Saving the fitted workflow did not work for me. When trying to predict with new data, it asks for the target variable (this is a churn model):

predict(churn_model, the_data)
Error: Problem with `mutate()` column `churn`.
i `churn = dplyr::if_else(churn == 1, "yes", "no")`.
x object 'churn' not found

I still don't get why it is asking for a column that should not be present in the data, since it is the variable I am trying to predict...
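For what it's worth, the traceback shows a dplyr::if_else() on churn being run while the new data is baked, which usually means the recipe itself recodes the outcome, roughly like this hypothetical sketch (the actual churn recipe is not shown here):

# Hypothetical recipe step that reproduces the error: the outcome is
# recoded inside the recipe, so bake() looks for `churn` at predict time
churn_rec <- recipe(churn ~ ., data = churn_train) %>%
  step_mutate(churn = dplyr::if_else(churn == 1, "yes", "no"))

# Possible fixes: recode the outcome in the training data before the
# recipe is built, or give the step skip = TRUE so it is not applied
# when baking new data at prediction time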

— Forge