
I followed the example of a thread-safe generator for Keras fit_generator given here: https://keras.io/utils/#sequence. In that example, each batch is built from its batch index (idx) alone, so the batch index is what ties a batch to a worker thread. In my case, I want the example index, not the batch index, to be what each thread owns. Here is my implementation:

import numpy as np

class CustomGenerator():

    def __init__(self):
        self.input_ = np.arange(0, 1000)
        self.labels = np.arange(0, 1000) * 0.1
        self.batch_sz = 5
        self.example_index = 0  # shared counter, advanced by every call

    def __len__(self):
        # Only the even-indexed examples are used (500 of them -> 100 batches),
        # and len() requires an int, not the float that np.ceil returns.
        return int(np.ceil(len(self.input_[::2]) / float(self.batch_sz)))

    def __getitem__(self, batch_idx):
        # batch_idx is not used: the batch is determined by self.example_index
        batch_x = np.zeros(self.batch_sz)
        batch_y = np.zeros(self.batch_sz)
        row = 0
        while row < self.batch_sz:
            if self.example_index % 2 == 0:  # keep only even-indexed examples
                batch_x[row] = self.input_[self.example_index]
                batch_y[row] = self.labels[self.example_index]
                row += 1
            self.example_index += 1

        return batch_x, batch_y

cg = CustomGenerator()

# Print the first four batches
for batch_idx in range(4):
    print(cg[batch_idx])

It prints the right output:

(array([0., 2., 4., 6., 8.]), array([0. , 0.2, 0.4, 0.6, 0.8]))
(array([10., 12., 14., 16., 18.]), array([1. , 1.2, 1.4, 1.6, 1.8]))
(array([20., 22., 24., 26., 28.]), array([2. , 2.2, 2.4, 2.6, 2.8]))
(array([30., 32., 34., 36., 38.]), array([3. , 3.2, 3.4, 3.6, 3.8]))
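The output looks right only because the batches are requested in order 0, 1, 2, and so on. As a minimal sketch, here is a re-creation of the class above using plain Python lists instead of NumPy arrays (the name CounterGenerator is just for illustration). Requesting the same batch index twice returns two different batches, because __getitem__ ignores batch_idx and reads the shared counter instead:

```python
class CounterGenerator:
    """Stdlib re-creation of the CustomGenerator above (lists instead of
    NumPy arrays): batch content comes from a shared counter, not batch_idx."""

    def __init__(self, n=1000, batch_sz=5):
        self.input_ = list(range(n))
        self.labels = [i * 0.1 for i in range(n)]
        self.batch_sz = batch_sz
        self.example_index = 0

    def __getitem__(self, batch_idx):
        # batch_idx is ignored: the result depends only on how many
        # calls happened before this one.
        batch_x, batch_y = [], []
        while len(batch_x) < self.batch_sz:
            if self.example_index % 2 == 0:
                batch_x.append(self.input_[self.example_index])
                batch_y.append(self.labels[self.example_index])
            self.example_index += 1
        return batch_x, batch_y


gen = CounterGenerator()
first = gen[0]
again = gen[0]  # same batch index, different data
print(first[0])  # [0, 2, 4, 6, 8]
print(again[0])  # [10, 12, 14, 16, 18]
```

Keras may request Sequence batches out of order (for example when batch order is shuffled) and from several workers at once, so a batch that depends on call history rather than on idx is fragile. With use_multiprocessing=True, each worker process would also work on its own copy of example_index.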

How can I make sure that this implementation works in a thread-safe manner, i.e., that different workers never use the same example_index when generating batches?
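One way to guarantee that no two workers use the same example is to drop the shared counter and compute the example indices from batch_idx, as the idx-based Sequence example linked above does. A minimal sketch, assuming the same data and the same "every even example, 5 per batch" layout; the class name EvenIndexSequence is hypothetical, and plain lists stand in for NumPy arrays so it runs without Keras (in real code it would subclass keras.utils.Sequence):

```python
import math


class EvenIndexSequence:
    """Hypothetical stateless variant: example indices are derived from batch_idx."""

    def __init__(self, n=1000, batch_sz=5, stride=2):
        self.input_ = list(range(n))
        self.labels = [i * 0.1 for i in range(n)]
        self.batch_sz = batch_sz
        # Indices of the examples that are actually used: 0, 2, 4, ...
        self.selected = list(range(0, n, stride))

    def __len__(self):
        # Must be an int; counts only the examples that are kept.
        return math.ceil(len(self.selected) / self.batch_sz)

    def __getitem__(self, batch_idx):
        # Pure function of batch_idx: no shared state, so concurrent or
        # out-of-order calls cannot interfere with each other.
        start = batch_idx * self.batch_sz
        idx = self.selected[start:start + self.batch_sz]
        batch_x = [self.input_[i] for i in idx]
        batch_y = [self.labels[i] for i in idx]
        return batch_x, batch_y


seq = EvenIndexSequence()
print(len(seq))      # 100
print(seq[3][0])     # [30, 32, 34, 36, 38]
```

Because batch 3 is always examples 30, 32, ..., 38 no matter which worker asks for it or when, this layout should be safe with threads and with worker processes alike.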

JMarc
  • A little late, but maybe this helps you: https://stackoverflow.com/questions/56441216/on-fit-generator-and-thread-safety – Markus Jun 13 '19 at 12:46
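If the shared counter really has to stay, holding a threading.Lock for the whole batch keeps two threads from reading the same example_index. A minimal sketch (LockedCounterGenerator is a hypothetical name; lists replace NumPy arrays), with two caveats: a lock only coordinates threads inside one process, so it does not help with use_multiprocessing=True, and which batch_idx receives which examples becomes nondeterministic:

```python
import threading
from concurrent.futures import ThreadPoolExecutor


class LockedCounterGenerator:
    """Hypothetical variant that keeps the shared example_index but guards
    it with a lock so concurrent threads never consume the same examples."""

    def __init__(self, n=1000, batch_sz=5):
        self.input_ = list(range(n))
        self.labels = [i * 0.1 for i in range(n)]
        self.batch_sz = batch_sz
        self.example_index = 0
        self._lock = threading.Lock()

    def __getitem__(self, batch_idx):
        batch_x, batch_y = [], []
        # Hold the lock for the whole batch so the counter's
        # read-modify-write cannot interleave with another thread's.
        with self._lock:
            while len(batch_x) < self.batch_sz:
                if self.example_index % 2 == 0:
                    batch_x.append(self.input_[self.example_index])
                    batch_y.append(self.labels[self.example_index])
                self.example_index += 1
        return batch_x, batch_y


locked = LockedCounterGenerator()
with ThreadPoolExecutor(max_workers=4) as pool:
    batches = list(pool.map(locked.__getitem__, range(20)))

# Every even example from 0 to 198 appears exactly once across the 20 batches.
seen = sorted(x for batch_x, _ in batches for x in batch_x)
print(seen == list(range(0, 200, 2)))  # True
```

The stateless, idx-driven layout avoids both caveats, which is why it is usually the simpler choice for a Sequence.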
