I have data in large files with a custom format. At the moment I am creating a tf.data.Dataset
with the names of those files, and then calling a tf.py_function
to access them as needed for my training. The py_function has to load the complete file (nearly 8GB) into memory in order to build an array of just a few MB (1024x1024x4). The py_function only returns that array and the corresponding label. The problem is that each sample that I load increases my CPU RAM usage by nearly 8GB. Very quickly, my computer runs out of RAM and the program crashes. When I run my program outside VS Code, I get through twice as many batches as when I use the debugger, but it's still 13 batches max. (I have 32GB of CPU RAM, and 13*8 > 32, so it looks like the memory gets freed sometimes, but maybe not fast enough?)
I keep the batch_size and prefetch both small, so only a few of these large arrays should need to be in memory at the same time. I expected TensorFlow to free that memory once the py_function returns and the big array goes out of scope. I tried to encourage the memory to be freed earlier by explicitly deleting the variable and calling the garbage collector, but that didn't help.
I don't think I can create a minimum working example because the data format and the methods to load the data are custom, but here are the relevant parts of my code:
import pickle
import gc
import tensorflow as tf
def tf_load_raw_data_and_labels(raw_data_files, label_files):
    # graph-side wrapper so the custom loader can be used in Dataset.map()
    [raw_data, labels] = tf.py_function(load_raw_data_and_labels, [raw_data_files, label_files], [tf.float32, tf.float32])
    raw_data.set_shape((1024, 1024, 4))
    labels.set_shape((1024,))
    return raw_data, labels

def load_raw_data_and_labels(raw_data_file, label_file):
    # load the ~8GB big_datacube, extract the (1024, 1024, 4) raw_data and the matching labels
    del big_datacube
    gc.collect()  # no noticeable difference
    return raw_data, labels
with open("train_tiles_list.pickle", "rb") as fid:
raw_data_files, label_files = pickle.load(fid)
train_dataset = tf.data.Dataset.from_tensor_slices((raw_data_files, label_files))
train_dataset = train_dataset.shuffle(n_train)#.repeat()
train_dataset = train_dataset.map(tf_load_raw_data_and_labels)
train_dataset = train_dataset.batch(batch_size)
train_dataset = train_dataset.prefetch(1)
I usually train a ResNet50 model using the model.fit() function of a tf.keras.Model, but I also tried the setup from the TF2 quickstart tutorial, which lets me set a debug point after each batch has trained. At the debug point I checked the memory usage of the active variables: the list is very short, and no variable is over 600KiB. At this point gc.collect() returns a number between 10 and 20, even after running it a few times, but I'm not too sure what that means.
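To make that check concrete, here is roughly what I run after each batch in the custom loop (a sketch only: train_step is the function from the quickstart tutorial, and psutil is just one convenient way to read the process's resident memory):

import os
import gc
import psutil  # assumption: any tool that reports the process RSS would do

def report_memory(step):
    # resident memory of this Python process, in GiB
    rss_gib = psutil.Process(os.getpid()).memory_info().rss / float(2 ** 30)
    collected = gc.collect()  # number of unreachable objects the collector found
    print("batch {}: RSS = {:.2f} GiB, gc.collect() -> {}".format(step, rss_gib, collected))

for step, (raw_data, labels) in enumerate(train_dataset):
    train_step(raw_data, labels)  # train_step as in the TF2 quickstart tutorial
    report_memory(step)           # this is where I set the debug point

The resident memory climbs by roughly the size of one big file per sample, while none of the variables visible in the debugger comes anywhere near 8GB.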
It might end up being easiest to crunch through all the large files once and save the smaller arrays to their own files before I start any training (see the sketch below). But for now, I'd like to understand whether there is something fundamental causing the memory not to be freed. Is it a memory leak? Perhaps related to tf.data.Datasets, py_functions, or something else specific to my setup?
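If I do go the preprocessing route, it would look roughly like this (a sketch only; it assumes my custom loader also accepts plain Python string paths outside the tf.data pipeline, and .npy is just a convenient container for the small arrays):

import numpy as np

# one-off pass: read each ~8GB file once, keep only the small arrays
for i, (raw_data_file, label_file) in enumerate(zip(raw_data_files, label_files)):
    raw_data, labels = load_raw_data_and_labels(raw_data_file, label_file)
    np.save("tile_{:05d}_data.npy".format(i), raw_data)
    np.save("tile_{:05d}_labels.npy".format(i), labels)

The training pipeline would then read these small files instead of the 8GB ones, so only a few MB per sample would ever need to be resident.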
Edit: I have read that Python's garbage collection was improved in Python 3.4. Because of a dependency related to my custom data, I am using Python 2.7. Could that be part of the problem?
Edit 2: I found some GitHub issues about memory leaks when using TensorFlow. The proposed workaround (tf.random.set_seed(1)) doesn't work for me:
https://github.com/tensorflow/tensorflow/issues/31253
https://github.com/tensorflow/tensorflow/issues/19671