3

I recently switched from TensorFlow 1.14 and the Estimator API to TensorFlow 2.0 and the Keras API. I am working on an image segmentation problem, so the inputs, outputs, and labels are all images. When I used Estimator, things were pretty straightforward. In model_fn, whose arguments were (features, labels, mode, params), I could just pick the features and labels, do the necessary processing, then pass them to tf.summary.image(), and everything worked like a charm. Now, with the Keras API, although it is easier to use, it is hard to do simple handling of data during training, and it becomes even harder when it is combined with the Dataset API. Example: TensorFlow 1.14/Estimator:

def model_fn(features, labels, mode, params):
    """Estimator ``model_fn`` for a segmentation network with image summaries.

    Args:
        features: batch of input images passed to ``model``.
        labels: ground-truth masks; ``None`` in PREDICT mode.
        mode: one of the ``estimator.ModeKeys`` values.
        params: dict read for 'lr', 'decay_steps', 'decay_rate', 'epsilon',
            'train_log_every_n_steps', 'eval_metrics_path' and
            'eval_steps_per_summary_save'.

    Returns:
        An ``estimator.EstimatorSpec`` for ``mode``.
    """
    loss, train_op = None, None
    eval_metric_ops, training_hooks, evaluation_hooks = None, None, None
    output = model(input=features)
    # argmax removes the last (class) axis, so `predictions` has one rank
    # less than `output`.
    predictions = tf.argmax(output, axis=-1)
    predictions_dict = {'predicted': predictions}
    if mode in (estimator.ModeKeys.TRAIN, estimator.ModeKeys.EVAL):
        # The metric needs labels, which only exist in TRAIN/EVAL; building it
        # unconditionally breaks PREDICT mode. Because the class axis is gone,
        # select class 1 by comparison instead of slicing a missing channel.
        foreground = tf.cast(tf.equal(predictions, 1), dtype=tf.float32)
        dice_score = tf.contrib.metrics.f1_score(labels=labels, predictions=foreground)
        global_step = tf.train.get_or_create_global_step()
        learning_rate = tf.train.exponential_decay(params['lr'], global_step=global_step,
                                                   decay_steps=params['decay_steps'],
                                                   decay_rate=params['decay_rate'], staircase=False)
        # argmax has no gradient, so the loss must be built from the raw model
        # output (assumes loss_fn expects the network output - confirm).
        loss = loss_fn(outputs=output, labels=labels)
        summary.image('Input_Image', features)
        summary.image('Label', tf.expand_dims(tf.cast(labels, dtype=tf.float32), axis=-1))
        summary.image('Prediction', tf.expand_dims(tf.cast(predictions, dtype=tf.float32), axis=-1))
    if mode == estimator.ModeKeys.TRAIN:
        with tf.name_scope('Metrics'):
            # dice_score[1] is the metric's update op; logging it also updates it.
            summary.scalar('Dice_Coefficient', dice_score[1])
            summary.scalar('Learning_Rate', learning_rate)
        summary.merge_all()
        train_logs_hook = tf.estimator.LoggingTensorHook({'Dice_Coefficient': dice_score[1]},
                                                         every_n_iter=params['train_log_every_n_steps'])
        training_hooks = [train_logs_hook]
        train_op = Adam(learning_rate=learning_rate, epsilon=params['epsilon']).minimize(loss=loss, global_step=global_step)
    if mode == estimator.ModeKeys.EVAL:
        eval_metric_ops = {'Metrics/Dice_Coefficient': dice_score}
        # Image summaries are not saved automatically during evaluation, so an
        # explicit SummarySaverHook writes them to the eval directory.
        eval_summary_hook = tf.estimator.SummarySaverHook(output_dir=params['eval_metrics_path'],
                                                          summary_op=summary.merge_all(),
                                                          save_steps=params['eval_steps_per_summary_save'])
        evaluation_hooks = [eval_summary_hook]
    return estimator.EstimatorSpec(mode,
                                   predictions=predictions_dict,
                                   loss=loss,
                                   train_op=train_op,
                                   eval_metric_ops=eval_metric_ops,
                                   training_hooks=training_hooks,
                                   evaluation_hooks=evaluation_hooks)

Using Keras with TensorFlow 2.0, AFAIK I can't have this kind of access to the input/output tensors during training or evaluation (notice that even though the Estimator doesn't produce image summaries during evaluation, you can still preview the results by using a tf.estimator.SummarySaverHook). Below is my failed attempt:

   def train_data(params):  # eval_data is built the same way
       """Build the training ``tf.data`` pipeline and log input/label image summaries.

       Args:
           params: dict read for 'global_step' (step attached to the image
               summaries), 'writer' (summary writer) and 'batch_size'; it is
               also forwarded to ``data_generator``.

       Returns:
           A batched, prefetched ``tf.data.Dataset`` of ``(image, label)`` pairs.
       """
       def standardization_summaries(image, label, step, writer):
           # Some processing to images
           # tf.summary.image requires a rank-4 [k, h, w, c] tensor, but the
           # generator yields rank-2 [h, w] tensors, so add batch and channel
           # axes for logging only; the returned tensors are left untouched.
           image_to_log = image[tf.newaxis, :, :, tf.newaxis]
           label_to_log = tf.cast(label, dtype=tf.float32)[tf.newaxis, :, :, tf.newaxis]
           with writer.as_default():
               tf.summary.image('Input_dataset', image_to_log, step=step, max_outputs=1)
               tf.summary.image('label_dataset', label_to_log, step=step, max_outputs=1)
           return image, label
       data_set = tf.data.Dataset.from_generator(generator=lambda: data_generator(params),
                                                 output_types=(tf.float32, tf.int64),
                                                 output_shapes=(tf.TensorShape([None, None]), tf.TensorShape([None, None])))
       # NOTE(review): this map runs once per element (before batching) and,
       # with prefetch, ahead of training, so it logs one summary per sample
       # rather than one per training step.
       data_set = data_set.map(lambda x, y: standardization_summaries(image=x, label=y, step=params['global_step'], writer=params['writer']))
       data_set = data_set.batch(params['batch_size'])
       data_set = data_set.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
       return data_set

    model = tf.keras.models.load_model(saved_model)
    summary_writer = tf.summary.create_file_writer(save_model_path)
    # Step attached to the dataset-side image summaries. Keras never advances
    # a user-created variable, so without `advance_step` below every summary
    # would be written at step 0.
    step = tf.Variable(0, trainable=False, dtype=tf.int64)
    advance_step = tf.keras.callbacks.LambdaCallback(
        on_batch_end=lambda batch, logs: step.assign_add(1))
    tensorboard = tf.keras.callbacks.TensorBoard(log_dir=save_model_path, histogram_freq=1, write_graph=True,
                                                 write_images=False)
    early_stop = tf.keras.callbacks.EarlyStopping(patience=args.early_stop)
    callbacks = [tensorboard, early_stop, advance_step]
    params = {'batch_size': args.batch_size,
              'global_step': step,
              'writer': summary_writer}

    # NOTE(review): because the input pipeline prefetches, the step read by
    # the dataset summaries may run slightly ahead of the batch being trained.
    model.fit(x=train_data(params), epochs=args.epochs, initial_epoch=args.initial_epoch,
              validation_data=val_data(params), steps_per_epoch=2, callbacks=callbacks)

Getting the input images from the Dataset API came from here, but this just produces tons of images whenever the dataset fetches data from the generator. Also, since the step variable stays constant (I can't figure out how to make it advance), everything is logged under step 0. Even if I found a way to print them, I can't think of any viable way to connect these outputs with the predicted output. So, the question is: is there anything I am still missing about how the Keras API and TensorBoard work together for image summaries? Is there a way to save image summaries, say, every half epoch during training and once at the end of evaluation? Or should I just let the model train, get the training outputs through model.predict() at the end of training, and then inspect whether something went wrong (which is not efficient)?

Georgios Livanos
  • 506
  • 3
  • 17

0 Answers