I want to train my model with Keras. I'm using a huge dataset where one training epoch has more than 30000 steps. My problem is that I don't want to wait for a full epoch before checking the model's improvement on the validation dataset. Is there any way to make Keras evaluate the validation data every 1000 training steps? I think one option would be to use a callback, but is there a built-in solution in Keras?

if train:
    log('Start training')
    history = model.fit(
        train_dataset,
        steps_per_epoch=train_steps,
        epochs=50,
        validation_data=val_dataset,
        validation_steps=val_steps,
        callbacks=[
            keras.callbacks.EarlyStopping(
                monitor='loss',
                patience=10,
                restore_best_weights=True,
            ),
            keras.callbacks.ModelCheckpoint(
                filepath='model.h5',
                monitor='val_loss',
                save_best_only=True,
                save_weights_only=True,
            ),
            keras.callbacks.ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.5,
                patience=3,
                min_lr=0.001,
            ),
        ],
    )
Timbus Calin
Jude TCHAYE

1 Answer

With the built-in callbacks, you cannot do that. What you need is to implement a custom callback.

import datetime

import tensorflow as tf


class MyCustomCallback(tf.keras.callbacks.Callback):
  # Prints a timestamp at the start and end of every training and evaluation batch.

  def on_train_batch_begin(self, batch, logs=None):
    print('Training: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))

  def on_train_batch_end(self, batch, logs=None):
    print('Training: batch {} ends at {}'.format(batch, datetime.datetime.now().time()))

  def on_test_batch_begin(self, batch, logs=None):
    print('Evaluating: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))

  def on_test_batch_end(self, batch, logs=None):
    print('Evaluating: batch {} ends at {}'.format(batch, datetime.datetime.now().time()))

This is taken from the TensorFlow documentation.

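You can override on_train_batch_end() and, since the batch parameter is an integer, check a condition such as batch % 1000 == 0 (to match your case) and then call self.model.evaluate(val_data) or self.model.predict(val_data), depending on your needs. Note that batch is the index within the current epoch, so it restarts at 0 every epoch.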
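As a rough illustration of that idea, here is a minimal sketch of such a callback. The names PeriodicValidation, every_n_steps, val_steps and val_history are made up for this example and are not part of Keras. It assumes the validation data is a tf.data.Dataset and keeps its own step counter, so the interval does not restart at each epoch boundary.

```python
import tensorflow as tf


class PeriodicValidation(tf.keras.callbacks.Callback):
    """Evaluate on the validation data every `every_n_steps` training batches.

    Hypothetical helper for illustration; not a built-in Keras callback.
    """

    def __init__(self, val_data, every_n_steps=1000, val_steps=None):
        super().__init__()
        self.val_data = val_data            # assumed to be a tf.data.Dataset
        self.every_n_steps = every_n_steps
        self.val_steps = val_steps          # None means "use the whole dataset"
        self.global_step = 0                # counts batches across all epochs
        self.val_history = []               # list of (global_step, metrics dict)

    def on_train_batch_end(self, batch, logs=None):
        self.global_step += 1
        if self.global_step % self.every_n_steps != 0:
            return
        # Run a silent evaluation pass and keep the metrics as a dict.
        results = self.model.evaluate(
            self.val_data,
            steps=self.val_steps,
            verbose=0,
            return_dict=True,
        )
        self.val_history.append((self.global_step, results))
        print('step {}: validation {}'.format(self.global_step, results))
```

It could then be added to the callbacks list in the question, for example PeriodicValidation(val_dataset, every_n_steps=1000, val_steps=val_steps). Keep in mind that every intermediate evaluation pauses training for as long as one pass over the validation data takes.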

Please check my answer here: How to get other metrics in Tensorflow 2.0 (not only accuracy)? for a good overview of how to write a custom callback. Note that in your case the method to override is on_train_batch_end(), not on_epoch_end().

Timbus Calin
I wanted to avoid writing a custom callback because I didn't know it was so easy. I made it work with the details you gave. Thank you very much. – Jude TCHAYE Jun 07 '20 at 15:40