Setup:

  • A U-Net network is trained on small patches (e.g. 64x64 pixels).
  • The network is fed a training dataset and a validation dataset through the TensorFlow Dataset API (tf.data).
  • The small patches are generated by randomly sampling much larger images.
  • Patch sampling happens during training: both training and validation patches are cropped on the fly.
  • TensorFlow 2.1 (eager execution mode).

Both the training and validation datasets are built with the same pipeline:

dataset = tf.data.Dataset.from_tensor_slices((large_images, large_targets))
dataset = dataset.shuffle(buffer_size=num_large_samples)
dataset = dataset.map(get_patches_from_large_images, num_parallel_calls=num_parallel_calls)
dataset = dataset.unbatch()
dataset = dataset.shuffle(buffer_size=num_small_patches)
dataset = dataset.batch(patches_batch_size)
dataset = dataset.prefetch(1)
dataset = dataset.repeat()

The function get_patches_from_large_images samples a predefined number of small patches from a single large image using tf.image.random_crop. It uses two nested loops, a for and a while. The outer for loop generates the predefined number of patches. The inner while loop checks whether a patch produced by tf.image.random_crop meets some predefined criteria (e.g. patches containing only background should be discarded), and it gives up after a predefined number of iterations, so it cannot get stuck. This approach is based on the solution presented here.

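patches = []  # created once, before the loop, so earlier patches are kept
for i in range(number_of_patches_from_one_large_image):
    num_tries = 0
    while num_tries < max_num_tries_befor_giving_up:
        # Crop the combined input/target image at a random position.
        patch = tf.image.random_crop(large_input_and_target_image, [patch_size, patch_size, 2])
        if patch_meets_some_criterions:  # placeholder for the actual check
            break
        num_tries = num_tries + 1
    # If the retry budget runs out, the last candidate is appended anyway.
    patches.append(patch)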
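The retry logic above can be tried out without TensorFlow. The following is only a framework-free sketch: random_crop here is a plain-Python stand-in for tf.image.random_crop that works on a 2-D list of lists, and is_foreground, sample_patches and their parameters are hypothetical names chosen for illustration, not part of the original code. As in the loop above, the last candidate is kept when the retry budget runs out (max_tries is assumed to be at least 1).

```python
import random


def random_crop(image, patch_size, rng):
    """Return a random patch_size x patch_size window of a 2-D list of lists."""
    height, width = len(image), len(image[0])
    top = rng.randrange(height - patch_size + 1)
    left = rng.randrange(width - patch_size + 1)
    return [row[left:left + patch_size] for row in image[top:top + patch_size]]


def is_foreground(patch):
    """Hypothetical criterion: accept a patch with at least one non-zero pixel."""
    return any(any(row) for row in patch)


def sample_patches(image, patch_size, num_patches, max_tries, rng):
    """Draw num_patches patches, retrying each one at most max_tries times."""
    patches = []
    for _ in range(num_patches):
        patch = None
        for _ in range(max_tries):
            patch = random_crop(image, patch_size, rng)
            if is_foreground(patch):
                break
        # The last candidate is appended even if every attempt was rejected.
        patches.append(patch)
    return patches
```

Because the number of returned patches is fixed, the output shape stays predictable for a later unbatch() step, at the price of occasionally letting a rejected patch through.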

Experiment:

  • The training and validation datasets fed to the model are the same (5 large input-target image pairs), and both produce exactly the same number of small patches from a single large image.
  • batch_size is the same for training and validation: 50 image patches.
  • steps_per_epoch and validation_steps are equal (20 batches).

When training is run with validation_freq=5:

unet_model.fit(dataset_train, epochs=10, steps_per_epoch=20, validation_data = dataset_val, validation_steps=20, validation_freq=5)


Train for 20 steps, validate for 20 steps
Epoch 1/10
20/20 [==============================] - 44s 2s/step - loss: 0.6771 - accuracy: 0.9038
Epoch 2/10
20/20 [==============================] - 4s 176ms/step - loss: 0.4952 - accuracy: 0.9820
Epoch 3/10
20/20 [==============================] - 4s 196ms/step - loss: 0.0532 - accuracy: 0.9916
Epoch 4/10
20/20 [==============================] - 4s 194ms/step - loss: 0.0162 - accuracy: 0.9942
Epoch 5/10
20/20 [==============================] - 42s 2s/step - loss: 0.0108 - accuracy: 0.9966 - val_loss: 0.0081 - val_accuracy: 0.9975
Epoch 6/10
20/20 [==============================] - 1s 36ms/step - loss: 0.0074 - accuracy: 0.9978
Epoch 7/10
20/20 [==============================] - 4s 175ms/step - loss: 0.0053 - accuracy: 0.9985
Epoch 8/10
20/20 [==============================] - 3s 169ms/step - loss: 0.0034 - accuracy: 0.9992
Epoch 9/10
20/20 [==============================] - 3s 171ms/step - loss: 0.0023 - accuracy: 0.9995
Epoch 10/10
20/20 [==============================] - 43s 2s/step - loss: 0.0016 - accuracy: 0.9997 - val_loss: 0.0013 - val_accuracy: 0.9998

We can see that the first epoch and the epochs with validation (every 5th epoch) took much longer than the epochs without validation. The same experiment with validation run every epoch gives the following result:

history = unet_model.fit(dataset_train, epochs=10, steps_per_epoch=20, validation_data = dataset_val, validation_steps=20)
Train for 20 steps, validate for 20 steps
Epoch 1/10
20/20 [==============================] - 84s 4s/step - loss: 0.6775 - accuracy: 0.8971 - val_loss: 0.6552 - val_accuracy: 0.9542
Epoch 2/10
20/20 [==============================] - 41s 2s/step - loss: 0.5985 - accuracy: 0.9833 - val_loss: 0.4677 - val_accuracy: 0.9951
Epoch 3/10
20/20 [==============================] - 43s 2s/step - loss: 0.1884 - accuracy: 0.9950 - val_loss: 0.0173 - val_accuracy: 0.9948
Epoch 4/10
20/20 [==============================] - 44s 2s/step - loss: 0.0116 - accuracy: 0.9962 - val_loss: 0.0087 - val_accuracy: 0.9969
Epoch 5/10
20/20 [==============================] - 44s 2s/step - loss: 0.0062 - accuracy: 0.9979 - val_loss: 0.0051 - val_accuracy: 0.9983
Epoch 6/10
20/20 [==============================] - 45s 2s/step - loss: 0.0039 - accuracy: 0.9989 - val_loss: 0.0033 - val_accuracy: 0.9991
Epoch 7/10
20/20 [==============================] - 44s 2s/step - loss: 0.0025 - accuracy: 0.9994 - val_loss: 0.0023 - val_accuracy: 0.9995
Epoch 8/10
20/20 [==============================] - 44s 2s/step - loss: 0.0019 - accuracy: 0.9996 - val_loss: 0.0017 - val_accuracy: 0.9996
Epoch 9/10
20/20 [==============================] - 44s 2s/step - loss: 0.0014 - accuracy: 0.9997 - val_loss: 0.0013 - val_accuracy: 0.9997
Epoch 10/10
20/20 [==============================] - 45s 2s/step - loss: 0.0012 - accuracy: 0.9998 - val_loss: 0.0011 - val_accuracy: 0.9998

Question: In the first example, the initialization of the training dataset (dataset_train) took about 40s. Subsequent epochs without validation were much shorter, about 4s each. However, each epoch that included validation again took about 40s. The validation dataset (dataset_val) is exactly the same as the training dataset (dataset_train), so I understand that initializing it takes about 40s. What surprises me is that every validation run is this expensive. I expected the first validation to take 40s and later ones about 4s: I assumed the validation dataset would behave like the training dataset, with a slow first fetch followed by much faster ones. Am I right, or am I missing something?

Update: I have checked that creating the iterator from the dataset takes about 40s:

dataset_val_it = iter(dataset_val) #40s

If we look inside the fit function, we see that a data_handler object is created once for the whole training run, and it returns the data iterator used in the main training loop. The iterator is created by calling the function enumerate_epochs. When fit wants to run validation, it calls the evaluate function. Every time evaluate is called, it creates a new data_handler object, which calls enumerate_epochs, which in turn creates a new iterator from the dataset. Unfortunately, for complicated datasets this is time-consuming.
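To make the difference between one long-lived iterator and a fresh iterator per validation run concrete, here is a toy, framework-free model. It assumes that the start-up cost comes from filling a shuffle buffer before the first element is returned, which is a plausible contributor for this pipeline but is not confirmed above. ToyShuffledDataset and the two helper functions are hypothetical names, not TensorFlow APIs.

```python
import random


class ToyShuffledDataset:
    """Toy stand-in for a dataset whose iterators must fill a buffer first."""

    def __init__(self, buffer_size):
        self.buffer_size = buffer_size
        self.elements_produced = 0  # upstream work, summed over all iterators

    def _source(self):
        while True:  # infinite, like a pipeline that ends in repeat()
            self.elements_produced += 1
            yield self.elements_produced

    def __iter__(self):
        source = self._source()
        # Start-up cost: paid again by every new iterator on its first element.
        buffer = [next(source) for _ in range(self.buffer_size)]
        while True:
            index = random.randrange(len(buffer))
            item = buffer[index]
            buffer[index] = next(source)
            yield item


def run_steps(iterator, steps):
    for _ in range(steps):
        next(iterator)


def work_with_fresh_iterators(dataset, rounds, steps):
    """Model of repeated evaluate() calls: a new iterator for every round."""
    for _ in range(rounds):
        run_steps(iter(dataset), steps)
    return dataset.elements_produced


def work_with_one_iterator(dataset, rounds, steps):
    """Model of the training loop: one iterator reused for every round."""
    iterator = iter(dataset)
    for _ in range(rounds):
        run_steps(iterator, steps)
    return dataset.elements_produced
```

With buffer_size=1000 and 3 rounds of 20 steps, the fresh-iterator variant produces 3060 upstream elements and the single-iterator variant 1060, which mirrors the pattern in the logs: the start-up cost is paid once for training but again for every validation run.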

1 Answer

If you just want a quick fix to speed up your input pipeline, you can try caching the elements of the validation dataset.
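A minimal sketch of what caching could look like, assuming a small finite validation set. random_patch is a hypothetical stand-in for the expensive cropping function, and the sizes are arbitrary. cache() is applied to a finite dataset and before any repeat(), because the cache is only complete after one full pass over its input.

```python
import tensorflow as tf


def random_patch(_):
    # Hypothetical stand-in for the expensive random cropping step.
    return tf.random.uniform([4, 4])


validation_steps = 20

patches = tf.data.Dataset.range(validation_steps).map(random_patch)

# cache() stores the elements during the first complete pass, so later
# passes replay them instead of re-running the map function.
dataset_val = patches.cache()

first_pass = [patch.numpy() for patch in dataset_val]
second_pass = [patch.numpy() for patch in dataset_val]
```

A side effect is that the validation patches are frozen after the first pass, so every validation run sees the same patches. For validation this is often desirable, since the metrics become comparable across epochs. Whether this removes the whole 40s delay in fit is not verified here.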

If we look inside the fit function, we see that a data_handler object is created once for the whole training run, and it returns the data iterator used in the main training loop. The iterator is created by calling the function enumerate_epochs. When fit wants to run validation, it calls the evaluate function. Every time evaluate is called, it creates a new data_handler object, which calls enumerate_epochs, which in turn creates a new iterator from the dataset. Unfortunately, for complicated datasets this is time-consuming.

I've never dug very deep into the tf.data code, but you seem to have a point here. I think it would be worth opening an issue on GitHub for this.

Kh4zit