I followed the example of a thread-safe generator for Keras fit_generator given here: https://keras.io/utils/#sequence
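For context, the pattern on that page is, roughly, a keras.utils.Sequence subclass whose __getitem__ depends only on the batch index (the class and variable names below are my paraphrase, not the docs' exact code):

import numpy as np
from keras.utils import Sequence

class DocsStyleSequence(Sequence):
    # Stateless pattern: every batch is a pure function of idx,
    # so any worker can safely fetch any batch in any order.
    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        # number of batches per epoch
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        # slice purely by the batch index; no mutable state is shared
        sl = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        return self.x[sl], self.y[sl]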
It looks like each worker is handed its own batch index (idx), so the batches can be produced independently. In my case, I instead want the workers to coordinate on a shared example index. Here is my implementation:
import numpy as np

class CustomGenerator:
    def __init__(self):
        self.input_ = np.arange(0, 1000)
        self.labels = np.arange(0, 1000) * 0.1
        self.batch_sz = 5
        self.example_index = 0  # shared mutable state across __getitem__ calls

    def __len__(self):
        # number of batches per epoch; must be an int for Keras
        return int(np.ceil(len(self.input_) / float(self.batch_sz)))

    def __getitem__(self, batch_idx):
        batch_x = np.zeros(self.batch_sz)
        batch_y = np.zeros(self.batch_sz)
        row = 0
        # advance example_index, keeping only even-indexed examples,
        # until the batch is full; note that batch_idx is never used
        while row < self.batch_sz:
            if self.example_index % 2 == 0:
                batch_x[row] = self.input_[self.example_index]
                batch_y[row] = self.labels[self.example_index]
                row += 1
            self.example_index += 1
        return batch_x, batch_y
cg = CustomGenerator()
batch_idx = 0
while True:
    print(cg[batch_idx])  # equivalent to cg.__getitem__(batch_idx)
    batch_idx += 1
It prints the expected output (first four batches shown):
(array([0., 2., 4., 6., 8.]), array([0. , 0.2, 0.4, 0.6, 0.8]))
(array([10., 12., 14., 16., 18.]), array([1. , 1.2, 1.4, 1.6, 1.8]))
(array([20., 22., 24., 26., 28.]), array([2. , 2.2, 2.4, 2.6, 2.8]))
(array([30., 32., 34., 36., 38.]), array([3. , 3.2, 3.4, 3.6, 3.8]))
How can I make sure that this implementation is thread-safe, i.e. that different workers will not read and advance the same example_index while generating their batches?
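One option I am considering is to guard the shared counter with a lock. This is only a sketch (the class name is mine, and I have not tested it with fit_generator), and I suspect it can only help when the workers are threads rather than separate processes:

import threading

class LockedCustomGenerator(CustomGenerator):
    # Hypothetical variant: serialize access to example_index so
    # only one worker thread advances it at a time.
    def __init__(self):
        super().__init__()
        self.lock = threading.Lock()

    def __getitem__(self, batch_idx):
        with self.lock:
            return super().__getitem__(batch_idx)

I am not sure whether something like this is sufficient once Keras spawns multiple workers.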