Declaring a Keras model as a `reactiveVal` means Shiny does not notice when the model is mutated during a call to `fit()`. What is the reason for this, and what's the best workaround? Is it possible to manually tell Shiny that some reactive variable has been mutated?
Keras models are generally mutated in place, and I suspect this is why Shiny misses the updates to the model's parameters during a training run. Or might the reason be that Keras is an interface to TensorFlow (i.e. some of the code is executed outside the scope of R and Shiny)?
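The in-place hypothesis can be checked without Keras at all. The sketch below assumes that a plain R environment is a fair stand-in for the model (environments have reference semantics, and the `dput()` output further down shows that a Keras model carries an environment-backed `py_object`); the object `model` and its `weight` field are hypothetical. Mutating the value returned by the getter changes what the `reactiveVal` holds without ever calling its setter, so Shiny gets no signal that anything happened.

```r
library(shiny)

# Hypothetical stand-in for a Keras model: like the reticulate handle
# behind a real model, an environment is modified by reference.
model <- new.env()
model$weight <- 0

MODEL <- reactiveVal(model)

# Grab the stored object through the getter (isolate() is needed outside
# a reactive context) and mutate it. The setter MODEL(...) is never called.
m <- isolate(MODEL())
m$weight <- 1

# Both names point at the same environment, so the stored object changed too.
print(isolate(MODEL())$weight)
print(identical(m, isolate(MODEL())))
```

As far as I can tell from Shiny's source, writing the object back with `MODEL(m)` would not help either: `reactiveVal` appears to skip the update when the new value is `identical()` to the stored one, and an environment is always identical to itself.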
In this self-contained MWE, an arbitrary weight is displayed. Although the model is trained, the displayed value never updates, even though the weight does change and is printed to the console after each epoch of training.
```r
library(keras)
library(shiny)

# Load data, reshape and normalize inputs and one-hot-encode labels
mnist <- dataset_mnist()
x.train <- array_reshape(mnist$train$x, c(nrow(mnist$train$x), 784)) / 255
y.train <- to_categorical(mnist$train$y, 10)

# Returns an untrained Keras model. Don't worry about the details.
initialize_model <- function() {
  keras_model_sequential() %>%
    layer_dense(units = 5, activation = 'relu', input_shape = c(784)) %>%
    layer_dense(units = 10, activation = 'softmax') %>%
    compile(loss = 'categorical_crossentropy', optimizer = optimizer_rmsprop())
}

ui <- fluidPage(
  fluidRow(
    column(width = 12, align = "center",
      actionButton('train.model', label = 'Train Model'),
      br(), br(),
      code(textOutput('random.weight'))
    )
  )
)

server <- function(input, output) {
  MODEL <- reactiveVal(initialize_model())
  RANDOM.WEIGHT <- reactive({ return(MODEL()$get_weights()[[3]][1, 2]) })

  observeEvent(input$train.model, {
    # this causes the model object to be mutated in place, but RANDOM.WEIGHT
    # is never updated
    MODEL() %>% fit(x.train, y.train, epochs = 1, batch_size = 60000)
    cat("The weight's new value is = ", MODEL()$get_weights()[[3]][1, 2], "\n")
  })

  output$random.weight <- renderText({
    return(paste0("RANDOM.WEIGHT() = ", RANDOM.WEIGHT()))
  })
}

app <- shinyApp(ui = ui, server = server)
runApp(app, port = 1337)
```
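On the question of manually telling Shiny about a mutation: I'm not aware of a public call that marks a `reactiveVal` as changed, but a common pattern is to keep the mutable model outside the reactive graph and bump a separate counter every time it is mutated. Below is a minimal sketch of that pattern. `make_fake_model()` and `fake_fit()` are hypothetical stand-ins for `initialize_model()` and `fit()` so that it runs without TensorFlow, and `model_version` is an illustrative name.

```r
library(shiny)

# Hypothetical stand-in for initialize_model(): an environment, so it is
# mutated in place just like the reticulate-backed Keras model.
make_fake_model <- function() {
  m <- new.env()
  m$weight <- 0
  m
}

# Hypothetical stand-in for fit(): changes the model in place and calls
# no reactive setter.
fake_fit <- function(model) {
  model$weight <- model$weight + 1
  invisible(model)
}

server <- function(input, output, session) {
  model <- make_fake_model()        # plain, non-reactive reference (per session)
  model_version <- reactiveVal(0)   # bumped after every in-place mutation

  weight <- reactive({
    model_version()                 # take a dependency on the counter
    model$weight                    # then read the mutable object
  })

  observeEvent(input$train.model, {
    fake_fit(model)
    model_version(model_version() + 1)  # tell Shiny the model changed
  })

  output$random.weight <- renderText({
    paste0("weight = ", weight())
  })
}
```

In the real app, `fake_fit(model)` would become `fit(model, x.train, y.train, epochs = 1, batch_size = 60000)` and `model$weight` would become `model$get_weights()[[3]][1, 2]`. Because the counter is a new number on every run, the `identical()` check never swallows the update.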
Any ideas? Thanks
Response to comments:

> What's the output of `dput(initialize_model())`?
```r
structure(function (object)
{
    compose_layer(object, x)
}, class = c("keras.engine.sequential.Sequential",
             "keras.engine.training.Model",
             "keras.engine.network.Network",
             "keras.engine.base_layer.Layer",
             "tensorflow.python.training.checkpointable.base.CheckpointableBase",
             "python.builtin.object"), py_object = <environment>)
```
Unfortunately, that doesn't really tell me much.
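The `py_object = <environment>` attribute does say something useful: the R value is only a handle to a Python object, so `fit()` changes the Python side while the R handle stays the very same object. Another way around this is to copy the values you display into a `reactiveVal` as plain R data after each run. Once the weights change, the new list of arrays is no longer `identical()` to the stored one, so dependents are invalidated. A minimal sketch, again with hypothetical stand-ins `make_fake_model()` and `fake_fit()`; with Keras the copy would be `model$get_weights()`.

```r
library(shiny)

# Hypothetical stand-in for a Keras model: an environment holding a list
# of weight matrices, mutated in place.
make_fake_model <- function() {
  m <- new.env()
  m$weights <- list(matrix(0, nrow = 2, ncol = 2))
  m
}

# Hypothetical stand-in for fit(): updates the weights in place.
fake_fit <- function(model) {
  model$weights[[1]] <- model$weights[[1]] + 1
  invisible(model)
}

server <- function(input, output, session) {
  model <- make_fake_model()

  # Plain R snapshot of the weights; replacing it with changed values
  # invalidates everything that reads WEIGHTS().
  WEIGHTS <- reactiveVal(model$weights)

  observeEvent(input$train.model, {
    fake_fit(model)
    WEIGHTS(model$weights)  # push a fresh copy into the reactive graph
  })

  RANDOM.WEIGHT <- reactive(WEIGHTS()[[1]][1, 2])

  output$random.weight <- renderText({
    paste0("RANDOM.WEIGHT() = ", RANDOM.WEIGHT())
  })
}
```

This copies every weight on each run, so for a large model it may be cheaper to snapshot only the values that are actually displayed.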