I am trying to train a neural network model with TF 2.3 on a GPU (an EC2 p2.xlarge instance with 61 GB of memory).
My code for the data pipeline:
import tensorflow as tf
my_dataset = tf.data.TFRecordDataset(train_files_path, compression_type='GZIP') # 500 files, each file is 500+ KB
my_dataset = my_dataset.map(lambda x: parse_input(x), num_parallel_calls=tf.data.experimental.AUTOTUNE)
epoch = 10
batch_size = 4096
my_dataset = my_dataset.shuffle(200)
my_dataset = my_dataset.repeat(epoch)
my_dataset = my_dataset.batch(batch_size, drop_remainder=True)
my_dataset = my_dataset.prefetch(batch_size)
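For completeness, parse_input is a standard tf.io.parse_single_example wrapper; the feature names and shapes below are simplified placeholders, not my real schema:
def parse_input(serialized_example):
    # Placeholder feature spec; the real one has more fields
    feature_spec = {
        'features': tf.io.FixedLenFeature([128], tf.float32),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }
    example = tf.io.parse_single_example(serialized_example, feature_spec)
    return example['features'], example['label']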
After running two epochs, the GPU ran out of memory and the Jupyter kernel died. I have tried the options from "Memory management in Tensorflow's Dataset API", "Does `tf.data.Dataset.repeat()` buffer the entire dataset in memory?", and "Why would this dataset implementation run out of memory?", but they did not help.
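One of the options I tried from those answers was enabling GPU memory growth before building the model; the snippet below is a rough reconstruction of what I ran, not the exact cell:
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    # Allocate GPU memory on demand instead of reserving it all up front
    tf.config.experimental.set_memory_growth(gpu, True)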
I also followed the best practices in https://www.tensorflow.org/guide/effective_tf2 and https://www.tensorflow.org/guide/data.
My pipeline:
map --> shuffle --> repeat --> batch --> prefetch
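For context, the dataset is consumed by a plain Keras fit call; the model below and the example count are placeholders for my real setup:
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])
model.compile(optimizer='adam', loss='binary_crossentropy')
num_examples = 500 * 10000  # placeholder; the real count comes from the TFRecord files
# repeat(epoch) is applied inside the dataset, so fit sees one long stream and
# steps_per_epoch defines how many batches count as one epoch
model.fit(my_dataset, epochs=epoch, steps_per_epoch=num_examples // batch_size)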
Did I miss something? Thanks.