I just upgraded to TensorFlow 2.3. I want to write my own data generator for training. With TensorFlow 1.x I did this:
import numpy as np

def get_data_generator(test_flag):
    # load_item_list, get_random_augmented_sample and BATCH_SIZE are defined elsewhere
    item_list = load_item_list(test_flag)
    print('data loaded')
    while True:
        X = []
        Y = []
        for _ in range(BATCH_SIZE):
            x, y = get_random_augmented_sample(item_list)
            X.append(x)
            Y.append(y)
        yield np.asarray(X), np.asarray(Y)
data_generator_train = get_data_generator(False)
data_generator_test = get_data_generator(True)

model.fit_generator(data_generator_train,
                    validation_data=data_generator_test,
                    epochs=10000, verbose=2,
                    use_multiprocessing=True,
                    workers=8,
                    validation_steps=100,
                    steps_per_epoch=500)
This code worked fine with TensorFlow 1.x: 8 worker processes were created, the CPU and GPU were fully utilized, and "data loaded" was printed 8 times.
With TensorFlow 2.3 I get this warning:
WARNING:tensorflow:multiprocessing can interact badly with TensorFlow, causing nondeterministic deadlocks. For high performance data pipelines tf.data is recommended.
"data loaded" was printed once(should 8 times). GPU is not fully utilized. It also have memory leak every epoch, so traning will stops after several epochs. use_multiprocessing flag did not help.
How can I make a generator/iterator in TensorFlow (Keras) 2.x that can easily be parallelized across multiple CPU processes? Deadlocks and data order are not important.
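Is a tf.keras.utils.Sequence subclass the intended replacement? Something like this rough, untested sketch (reusing my load_item_list / get_random_augmented_sample helpers and BATCH_SIZE from above):

import numpy as np
import tensorflow as tf

class RandomAugmentedSequence(tf.keras.utils.Sequence):
    # Rough sketch only: each __getitem__ builds one random batch
    # using the same helpers as my generator above.
    def __init__(self, test_flag, steps):
        self.item_list = load_item_list(test_flag)
        self.steps = steps  # number of batches per epoch
        print('data loaded')

    def __len__(self):
        return self.steps

    def __getitem__(self, index):
        X, Y = [], []
        for _ in range(BATCH_SIZE):
            x, y = get_random_augmented_sample(self.item_list)
            X.append(x)
            Y.append(y)
        return np.asarray(X), np.asarray(Y)

seq_train = RandomAugmentedSequence(False, steps=500)
seq_test = RandomAugmentedSequence(True, steps=100)
model.fit(seq_train, validation_data=seq_test,
          epochs=10000, verbose=2,
          use_multiprocessing=True, workers=8)

Would something like this actually spread the augmentation work across 8 CPU processes in 2.3, or do I have to move everything into a tf.data pipeline?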