
I came across this notebook that covers forecasting. I got it through this article.

I am confused about the 2nd and 4th lines of the snippet below:

train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_data = train_data.cache().shuffle(buffer_size).batch(batch_size).repeat()

val_data = tf.data.Dataset.from_tensor_slices((x_vali, y_vali))
val_data = val_data.batch(batch_size).repeat()

I understand that we are trying to shuffle our data because we don't want to feed it to the model in serial order. On additional reading I realized that it is better to have buffer_size equal to the size of the dataset. But I am not sure what repeat is doing in this case. Could someone explain what is being done here and what the function of repeat is?

I also looked at this page and saw the text below, but it is still not clear to me.

The following methods of `tf.data.Dataset`:

`repeat(count=0)`: The method repeats the dataset count number of times.

`shuffle(buffer_size, seed=None, reshuffle_each_iteration=None)`: The method shuffles the samples in the dataset. The buffer_size is the number of samples which are randomized and returned as tf.data.Dataset.

`batch(batch_size, drop_remainder=False)`: Creates batches of the dataset with batch size given as batch_size, which is also the length of the batches.

1 Answer


The `repeat()` call with nothing passed to the count parameter makes the dataset repeat indefinitely.
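A minimal sketch of the difference, using toy data rather than the notebook's:

```python
import tensorflow as tf

ds = tf.data.Dataset.range(3)

# With an explicit count, the data is cycled that many times and then ends.
print(list(ds.repeat(2).as_numpy_iterator()))        # [0, 1, 2, 0, 1, 2]

# With no count, the dataset cycles forever; take() caps it so we can print.
print(list(ds.repeat().take(7).as_numpy_iterator())) # [0, 1, 2, 0, 1, 2, 0]
```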

In Python terms, a Dataset implements the iterable protocol. If you have an object ds of type tf.data.Dataset, you can call iter(ds) on it. If the dataset was built with repeat() (no count), the resulting iterator never runs out of items, i.e., it never raises a StopIteration exception.
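For example (a small illustration, not code from the notebook):

```python
import tensorflow as tf

finite = tf.data.Dataset.range(3)
infinite = tf.data.Dataset.range(3).repeat()

# A finite dataset's iterator is exhausted after its last element.
it = iter(finite)
for _ in range(3):
    next(it)
try:
    next(it)
except StopIteration:
    print("finite dataset exhausted")

# A repeat()-ed dataset keeps cycling and never raises StopIteration.
it = iter(infinite)
print([int(next(it)) for _ in range(7)])  # 0, 1, 2, 0, 1, 2, 0
```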

In the notebook you referenced, the call to tf.keras.Model.fit() passes steps_per_epoch=100. This means the dataset is expected to repeat indefinitely, and Keras treats every 100 batches as one epoch, pausing training to run validation at the end of each one.
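A condensed sketch of that setup (the shapes, model, and batch size below are placeholders, not taken from the notebook):

```python
import numpy as np
import tensorflow as tf

# Placeholder data standing in for the notebook's windowed series.
x_train = np.random.rand(1000, 20, 1).astype("float32")
y_train = np.random.rand(1000, 1).astype("float32")

train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_data = train_data.cache().shuffle(1000).batch(32).repeat()

model = tf.keras.Sequential([
    tf.keras.Input(shape=(20, 1)),
    tf.keras.layers.LSTM(8),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mae")

# Because train_data repeats forever, steps_per_epoch is what defines an
# "epoch": Keras pulls 100 batches, then moves on to the next epoch.
model.fit(train_data, epochs=2, steps_per_epoch=100)
```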

tldr: leave it in.

https://github.com/tensorflow/tensorflow/blob/3f878cff5b698b82eea85db2b60d65a2e320850e/tensorflow/python/data/ops/dataset_ops.py#L134-L3445

https://docs.python.org/3/library/exceptions.html

  • In the case of this specific example, is it safe to delete the `repeat` part then, or is it necessary? – user2543622 Apr 18 '22 at 17:27
    It depends on your training loop. If your training loop is expecting an infinite dataset, then you need to keep the repeat. Usually, people train with a number of epochs where an epoch is a full sweep of the dataset, so in that case you would remove the repeat. – Yaoshiang Apr 19 '22 at 23:27
  • Thanks, I am running for 10 epochs and will remove the `repeat` part. Do you agree that it is better to have `buffer_size` the same as the size of the dataset? Any thoughts on what my `buffer_size` should be? – user2543622 Apr 20 '22 at 00:42
    You want a buffer size as big as physically possible to get the best shuffle. If you are using TFRecords as the underlying disk storage, note that TFRecord is based on a Google technology called RecordIO, which was designed for spinning magnetic platter disks. With SSDs, the whole TFRecord -> tf.data.Dataset pipeline should be revamped, but it's too ingrained into the API. I'm working on a loader that shuffles via random access to a TFRecord rather than loading it serially and performing a weak shuffle. Determined released YogaDL in the same vein. – Yaoshiang Apr 20 '22 at 01:35