
I have a generator that looks like this:

def data_generator(data_file, index_list, ....):
    orig_index_list = index_list
    while True:
        x_list = list()
        y_list = list()
        if patch_shape:
            index_list = create_patch_index_list(orig_index_list, data_file, patch_shape,
                                                 patch_overlap, patch_start_offset,
                                                 pred_specific=pred_specific)
        else:
            index_list = copy.copy(orig_index_list)

        while len(index_list) > 0:
            index = index_list.pop()
            add_data(x_list, y_list, data_file, index, augment=augment, augment_flip=augment_flip,
                     augment_distortion_factor=augment_distortion_factor, patch_shape=patch_shape,
                     skip_blank=skip_blank, permute=permute)
            if len(x_list) == batch_size or (len(index_list) == 0 and len(x_list) > 0):
                yield convert_data(x_list, y_list, n_labels=n_labels, labels=labels,
                                   num_model=num_model, overlap_label=overlap_label)
                x_list = list()
                y_list = list()

My dataset is 55 GB and is stored as a .h5 file (data.h5). Reading the data is extremely slow: one epoch takes about 7000 s, and I get a segmentation fault after roughly 6 epochs.

I thought that setting use_multiprocessing=False with workers > 1 would speed up reading the data:

model.fit(use_multiprocessing=False, workers=8)

But when I do that I get the following error:

RuntimeError: Your generator is NOT thread-safe. Keras requires a thread-safe generator when use_multiprocessing=False, workers > 1.

Is there a way to make my generator thread-safe? Or is there any other efficient way to generate this data?

Dushi Fdz
  • Does this answer your question? [Are Generators Threadsafe?](https://stackoverflow.com/questions/1131430/are-generators-threadsafe). See in particular the answer that features the `LockedIterator` class (second answer). Actually, I think that `LockedIterator` class is wrong. – Booboo Aug 09 '21 at 19:59
  • No, it doesn't. I tried other solutions posted but nothing worked. My question is how I can make the above generator thread-safe so I can set `use_multiprocessing=False, workers > 1` and check if there is any improvement in the speed of the data loading process. My ultimate objective is to make training faster, so if someone knows any other efficient way to load data that would be even better. – Dushi Fdz Aug 09 '21 at 20:07
  • Oops. I copied and pasted incorrectly. See my answer below. Take a look and let me know if you follow it. Also, if you have a question about efficiency, *that* is a separate, future post. Do not piggy-back questions like that. Posts get closed when they ask more than one question. – Booboo Aug 09 '21 at 20:13

1 Answer


I believe the LockedIterator class I referenced in my comment above is incorrect; it should be coded as in the example below:

import threading

class LockedIterator(object):
    """Wrap an iterator so that advancing it is serialized by a lock."""

    def __init__(self, it):
        self.lock = threading.Lock()
        self.it = iter(it)

    def __iter__(self):
        return self

    def __next__(self):
        # Only one thread at a time may advance the underlying iterator.
        with self.lock:
            return next(self.it)

def gen():
    for x in range(10):
        yield x

new_gen = LockedIterator(gen())

def worker(g):
    for x in g:
        print(x, flush=True)

t1 = threading.Thread(target=worker, args=(new_gen,))
t2 = threading.Thread(target=worker, args=(new_gen,))
t1.start()
t2.start()
t1.join()
t2.join()

Prints:

0
1
23

4
5
6
7
8
9

If you want to guarantee that the printed output prints one value per line, then we would also need to pass a threading.Lock instance to each thread and issue the print statement under control of that lock so printing is serialized.
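
For instance, here is a minimal sketch of that serialized-printing variant, reusing the LockedIterator and gen defined above; the print_lock and locked_worker names are purely illustrative and not part of the original code:

import threading

print_lock = threading.Lock()

def locked_worker(g, lock):
    for x in g:
        # Hold the lock for the whole print call so output lines never interleave.
        with lock:
            print(x, flush=True)

new_gen = LockedIterator(gen())

t1 = threading.Thread(target=locked_worker, args=(new_gen, print_lock))
t2 = threading.Thread(target=locked_worker, args=(new_gen, print_lock))
t1.start()
t2.start()
t1.join()
t2.join()

With this change every value from 0 to 9 is printed exactly once on its own line, although which thread prints which value, and in what order, is still arbitrary.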

Booboo
  • I don't understand how I can adapt this to the generator I have posted in my question. Can you please explain a little bit more? – Dushi Fdz Aug 09 '21 at 20:13
  • Where you currently have data_generator(*actual parameters*) use instead LockedIterator(data_generator(*actual parameters*)). Of course, it would have helped had you published how this generator was being referenced. Unfortunately, I am not familiar with Keras to guess. – Booboo Aug 09 '21 at 20:27
  • I asked in another question about efficiency. There I have provided more details about how this generator was being referenced: https://stackoverflow.com/questions/68705944/reading-h5-file-is-extremely-slow?noredirect=1#comment121421402_68705944 – Dushi Fdz Aug 09 '21 at 20:30
  • So you would want for *this* question `training_generator = LockedIterator(data_generator(data_file, training_list,....))` and you would pass `training_generator` to each thread (see the sketch after this comment thread). – Booboo Aug 09 '21 at 20:38
  • I edited the way you suggested: `training_generator = LockedIterator(data_generator(data_file, training_list,....))`. And then `model.fit(use_multiprocessing=False, workers=8)`. It didn't give me an error yet. Waiting for the first epoch to complete to check if there is any improvement in the speed. – Dushi Fdz Aug 09 '21 at 20:42
  • Whether it makes it faster or not, don't forget that the question was "How to make a generator thread-safe?" – Booboo Aug 09 '21 at 20:44
  • Sure. I will accept the answer once the first epoch completes with no errors. If you can, please have a look at the other question I referenced about efficiency. – Dushi Fdz Aug 09 '21 at 20:46
  • I did take a look. Unfortunately, I am not familiar with either the .h5 file format or `pytables`. Sorry. – Booboo Aug 09 '21 at 21:50
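
Putting the suggestion from the comments together, here is a hypothetical, self-contained sketch of the wiring, assuming a TF 2.x-era Keras (current when this was asked) whose model.fit still accepts the workers and use_multiprocessing arguments. The LockedIterator class is the one from the answer; toy_generator, the tiny model, and the steps_per_epoch/epochs numbers are placeholders standing in for the question's data_generator and model, not code from the thread:

import threading

import numpy as np
import tensorflow as tf

class LockedIterator(object):
    def __init__(self, it):
        self.lock = threading.Lock()
        self.it = iter(it)

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return next(self.it)

def toy_generator(batch_size=32):
    # Endless stream of (x, y) batches; in the real setup,
    # data_generator(data_file, training_list, ....) from the question goes here.
    while True:
        x = np.random.rand(batch_size, 16).astype("float32")
        y = np.random.randint(0, 2, size=(batch_size, 1)).astype("float32")
        yield x, y

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation="relu", input_shape=(16,)),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy")

# Wrap the generator so several worker threads can pull batches from it safely.
training_generator = LockedIterator(toy_generator())

# steps_per_epoch is needed because the generator is endless.
model.fit(
    training_generator,
    steps_per_epoch=10,
    epochs=2,
    use_multiprocessing=False,
    workers=8,
)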