
I am trying to train a neural network model with TF2.3 on GPU (EC2/P2.xlarge with 61GB memory).

My code for the data pipeline:

 my_dataset = tf.data.TFRecordDataset(train_files_path, compression_type='GZIP') # 500 files, each file size is 500+ KB
 my_dataset = my_dataset.map(lambda x: parse_input(x), num_parallel_calls=tf.data.experimental.AUTOTUNE)
 epoch = 10
 batch_size = 4096
 my_dataset = my_dataset.shuffle(200)
 my_dataset = my_dataset.repeat(epoch)
 my_dataset = my_dataset.batch(batch_size, drop_remainder=True)
 my_dataset = my_dataset.prefetch(batch_size)

After running two epochs, the GPU ran out of memory and the Jupyter kernel died. I have tried the options from "Memory management in Tensorflow's Dataset API", "Does `tf.data.Dataset.repeat()` buffer the entire dataset in memory?", and "Why would this dataset implementation run out of memory?", but they did not help.

I also followed the best practices at https://www.tensorflow.org/guide/effective_tf2 and https://www.tensorflow.org/guide/data

My pipeline:

map --> shuffle --> repeat --> batch --> prefetch 

Did I miss something? Thanks.

user3448011

1 Answer


I've run into this problem and it can be pretty frustrating; sometimes there is no good solution. Have you tried decreasing your batch size? 4096 is fairly large, and training often succeeds with smaller values; going down to 256-512, or even lower, has helped me.
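For example, a minimal sketch of just that change (256 is an arbitrary starting point; everything else is assumed to stay as in your question):

 batch_size = 256  # was 4096; increase again gradually while memory allows
 my_dataset = my_dataset.batch(batch_size, drop_remainder=True)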

Another small thing: I don't believe you need to define an anonymous function with the lambda keyword in map. If the only argument is the element being passed in, you can just do

 my_dataset = my_dataset.map(parse_input, num_parallel_calls=tf.data.experimental.AUTOTUNE)

You could also try skipping the prefetch call; it speeds up the input pipeline at the cost of extra memory.
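One caveat with the current call: after .batch(), the buffer_size argument of prefetch counts batches, so prefetch(batch_size) with batch_size = 4096 can buffer up to 4096 batches of 4096 examples each. As a rough sketch (not a drop-in fix), the pipeline with the lambda dropped, a smaller batch size, and prefetch either removed or limited to a small buffer could look like this; parse_input, train_files_path, and the shuffle/repeat settings are taken from your question:

 import tensorflow as tf

 epoch = 10
 batch_size = 256  # much smaller than 4096 to lower peak memory

 my_dataset = tf.data.TFRecordDataset(train_files_path, compression_type='GZIP')
 my_dataset = my_dataset.map(parse_input, num_parallel_calls=tf.data.experimental.AUTOTUNE)
 my_dataset = my_dataset.shuffle(200)
 my_dataset = my_dataset.repeat(epoch)
 my_dataset = my_dataset.batch(batch_size, drop_remainder=True)
 # If you keep prefetch, a buffer of 1 batch (or AUTOTUNE) is usually enough:
 # my_dataset = my_dataset.prefetch(1)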

vstack17
  • thanks, I used TensorFlow 1.15 to implement the same model and the training process used no more than 6 GB on the same EC2 instance (CPU and GPU). It seems that I am missing some "smart" memory-control mechanism in TensorFlow 2.3? Also, shouldn't "prefetch" help improve training time, since it can hide the data-loading time while the GPU or CPU is busy with model training? – user3448011 Mar 23 '22 at 01:23