I see a memory leak and steadily decreasing performance when I call a Keras model's predict function in a loop and feed it from a tf.data.Dataset, but not when I feed it a numpy array.
Does anyone understand what is causing this and/or how to resolve the issue?
Minimal reproducible code snippet (copy/paste runnable):
import tensorflow as tf
import numpy as np
import time
SIZE = 5000
inp = tf.keras.layers.Input(shape=(SIZE,), dtype='float32')
x = tf.keras.layers.Dense(units=SIZE)(inp)
model = tf.keras.Model(inputs=inp, outputs=x)
np_data = np.random.rand(1, SIZE)
ds = tf.data.Dataset.from_tensor_slices(np_data).batch(1).repeat()
debug_time = time.time()
while True:
    model.predict(x=ds, steps=1)
    print('Processing {:.2f}'.format(time.time() - debug_time))
    debug_time = time.time()
Result: each predict iteration starts at around 0.04s. Within a minute or two it is up to about 0.5s, and process memory keeps growing from a few hundred MB to close to 1 GB.
If I swap out the tf.data.Dataset for an equivalent numpy array, the runtime stays at a consistent ~0.01s per iteration.
Working case code snippet (copy/paste runnable):
import tensorflow as tf
import numpy as np
import time
SIZE = 5000
inp = tf.keras.layers.Input(shape=(SIZE,), dtype='float32')
x = tf.keras.layers.Dense(units=SIZE)(inp)
model = tf.keras.Model(inputs=inp, outputs=x)
np_data = np.random.rand(1, SIZE)
debug_time = time.time()
while True:
    model.predict(x=np_data)  # using numpy array directly
    print('Processing {:.2f}'.format(time.time() - debug_time))
    debug_time = time.time()
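Both snippets loop forever and only print timings, which makes the two cases hard to compare side by side. The following is a minimal sketch of a bounded harness for that comparison, not part of the original report: profile_calls and slowdown_ratio are hypothetical helper names, and the memory column assumes a Unix platform, since it reads the process's peak RSS through resource.getrusage.

```python
import resource
import time


def profile_calls(fn, iterations=100, report_every=10):
    """Call fn() `iterations` times and record latency and peak RSS per call.

    Returns a list of (iteration, seconds, peak_rss) tuples. peak_rss comes
    from resource.getrusage (Unix only): kilobytes on Linux, bytes on macOS.
    It is a high-water mark, so it never decreases; a value that keeps
    climbing suggests the process keeps allocating memory it does not free.
    """
    samples = []
    for i in range(iterations):
        start = time.perf_counter()
        fn()
        elapsed = time.perf_counter() - start
        peak_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        samples.append((i, elapsed, peak_rss))
        # Print a progress line every `report_every` calls (0 disables it).
        if report_every and (i + 1) % report_every == 0:
            print('iter {:5d}  {:.3f}s  peak RSS {}'.format(i + 1, elapsed, peak_rss))
    return samples


def slowdown_ratio(samples, window=10):
    """Mean latency of the last `window` calls divided by that of the first."""
    if len(samples) < window:
        raise ValueError('need at least {} samples'.format(window))
    first = [s[1] for s in samples[:window]]
    last = [s[1] for s in samples[-window:]]
    return (sum(last) / window) / (sum(first) / window)
```

For example, run profile_calls(lambda: model.predict(x=ds, steps=1), iterations=500) in one process and profile_calls(lambda: model.predict(x=np_data), iterations=500) in another. Separate processes matter because peak RSS is a process-wide high-water mark. If the behaviour described above reproduces, the dataset version should show a slowdown ratio well above 1 and a climbing peak RSS, while the numpy version should stay roughly flat.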
Related discussions:
- Memory leak tf.data + Keras - Doesn't seem to address the core issue, but the question appears similar.
- https://github.com/tensorflow/tensorflow/issues/22098 - Possibly an open issue in Keras/GitHub, but I can't confirm it; changing inter_op_parallelism as suggested in that thread has no impact on the results posted here.
Additional info:
- I can reduce the rate of performance degradation by around 10x by passing in an iterator instead of a dataset object. I noticed that at training_utils.py:1314 the Keras code creates a new iterator on each call to predict.
- TF 1.14.0