2

I have a directory with around a million images. I want to create a batch_generator so that I can train my CNN, since I cannot hold all those images in memory at once.

So I wrote a generator function:

def batch_generator(image_paths, batch_size, isTraining):
    while True:
        batch_imgs = []
        batch_labels = []
        
        type_dir = 'train' if isTraining else 'test'
        
        for i in range(len(image_paths)):
            print(i)
            print(os.path.join(data_dir_base, type_dir, image_paths[i]))
            img = cv2.imread(os.path.join(data_dir_base, type_dir, image_paths[i]), 0)
            img  = np.divide(img, 255)
            img = img.reshape(28, 28, 1)
            batch_imgs.append(img)
            label = image_paths[i].split('_')[1].split('.')[0]
            batch_labels.append(label)
            if len(batch_imgs) == batch_size:
                yield (np.asarray(batch_imgs), np.asarray(batch_labels))
                batch_imgs = []
        if batch_imgs:
            yield batch_imgs

When I call this statement:

index = next(batch_generator(train_dataset, 10, True))

It prints the same index values and paths every time, so every call to next() returns the same batch. How do I fix this?

I used this question as a reference for the code: how to split an iterable in constant-size chunks

Tomerikoo
  • @kerwei nope, it's correctly indented, it's here to yield the last batch if its size was < batch_size. It's a very very common "buffering" code pattern. – bruno desthuilliers Jan 18 '19 at 08:06
  • @brunodesthuilliers Yes, I didn't notice the inner if block at first glance. Hence, deleted my comment :) – kerwei Jan 18 '19 at 08:10

4 Answers

1
import numpy as np

# batch generator: yields shuffled (X, Y) mini-batches for one epoch
def get_batches(dataset, batch_size):
    X, Y = dataset
    n_samples = X.shape[0]

    # Shuffle at the start of epoch
    indices = np.arange(n_samples)
    np.random.shuffle(indices)

    for start in range(0, n_samples, batch_size):
        end = min(start + batch_size, n_samples)

        batch_idx = indices[start:end]

        yield X[batch_idx], Y[batch_idx]
David Buck
0

Generator functions are not generators themselves but "generator factories" - each time you call batch_generator(...) it returns a fresh new generator, ready to start again. IOW, you want:

gen = batch_generator(...)
for batch in gen:       
    do_something_with(batch)
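A minimal sketch of the difference, using a hypothetical toy generator function `count_up_to` (not part of the question's code). Calling the function inside `next()` builds a new generator every time, while keeping a reference to one generator lets `next()` advance through it:

```python
def count_up_to(n):
    """Hypothetical toy generator function: yields 0, 1, ..., n - 1."""
    for i in range(n):
        yield i

# Each call builds a brand-new generator, so this always restarts at 0.
first = next(count_up_to(3))
again = next(count_up_to(3))
print(first, again)  # 0 0

# Keeping a reference to ONE generator lets next() advance through it.
gen = count_up_to(3)
print(next(gen), next(gen), next(gen))  # 0 1 2
```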

Also:

1/ the way you wrote your generator function will create an infinite generator - the outer while loop will repeat forever - which might or might not be what you expected (thought I'd better warn you).

2/ there are two logical errors in your code: first, you don't reset the batch_labels list, then on the last yield you only yield batch_imgs, which is not consistent with the inner yield. FWIW, instead of maintaining two lists (one for the images and the other for the labels), you'd perhaps be better off using one single list of (img, label) tuples.
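One way the single-list idea could look is sketched below. This is not the asker's final code: `load` is a hypothetical callable standing in for the `cv2.imread` and preprocessing steps, and `fake_load` only exists so the snippet runs without any image files. It assumes `items` is a non-empty list (with an empty list the outer loop would spin forever without yielding).

```python
def batch_generator(items, batch_size, load):
    """Yield lists of (image, label) tuples, at most batch_size long.

    `load` is a hypothetical callable that turns one item (for example
    a file path) into an (image, label) tuple.
    """
    while True:  # infinite on purpose: one pass over `items` per epoch
        batch = []
        for item in items:
            batch.append(load(item))
            if len(batch) == batch_size:
                yield batch
                batch = []  # a single list, so nothing is left un-reset
        if batch:
            yield batch  # last, shorter batch has the same structure


def fake_load(path):
    # Stand-in for reading and preprocessing an image; the label is
    # parsed as in the question (text between '_' and '.').
    label = path.split('_')[1].split('.')[0]
    return path.upper(), label


paths = ['a_0.png', 'b_1.png', 'c_0.png']  # made-up file names
gen = batch_generator(paths, 2, fake_load)
print(next(gen))  # [('A_0.PNG', '0'), ('B_1.PNG', '1')]
print(next(gen))  # [('C_0.PNG', '0')]
```

If separate image and label sequences are needed downstream, `imgs, labels = zip(*batch)` splits a batch back into two tuples.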

And as a final side note: you don't need to use range(len(lst)) to iterate on a list - Python's for loop is of the foreach type, it directly iterates over the iterable's items, ie:

for path in image_paths:
    print(path)

works just the same, is more readable, and is a bit faster...
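If the index is still wanted (the question prints it for debugging), the built-in `enumerate` provides it without `range(len(...))`. A small sketch with made-up file names:

```python
image_paths = ['a_0.png', 'b_1.png']  # hypothetical file names

# enumerate yields (index, item) pairs while iterating directly.
for i, path in enumerate(image_paths):
    print(i, path)
```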

bruno desthuilliers
  • About the outer loop, I am going to use the generator in keras to train a CNN. So, the course I did on that used a similar implementation of the batch generator. Can you explain more about the drawbacks or benefits of infinite generator? –  Jan 18 '19 at 10:28
  • What you're going to use to "train a CNN" are (directly or indirectly) the results of iterating over your generator, not the generator itself. And the principle of an infinite generator is that the iteration never stops - `next(iterator)` will _always_ return something and a `for item in iterator` loop will run forever. It's impossible to tell if an infinite generator is appropriate for your own use case without seeing exactly how it's used, I just thought you might want to be warned about this since you don't really seem to fully grasp what generators are and how they work. – bruno desthuilliers Jan 18 '19 at 11:14
  • Yeah, I don't have a complete understanding of generators. This is my first time. But I have got it to work. Thanks for the help. –  Jan 19 '19 at 08:17
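To illustrate the point about infinite generators from the comments above, here is a small sketch with a hypothetical `cycle_batches` function (it assumes a non-empty list). Because iteration never stops on its own, the consumer has to decide how many batches to take, for example with `itertools.islice`:

```python
from itertools import islice

def cycle_batches(items, batch_size):
    """Hypothetical infinite generator: restarts over `items` forever."""
    while True:
        for start in range(0, len(items), batch_size):
            yield items[start:start + batch_size]

gen = cycle_batches([1, 2, 3], 2)

# A plain `for batch in gen:` would never end, so cap the number of
# batches explicitly, here with itertools.islice.
first_five = list(islice(gen, 5))
print(first_five)  # [[1, 2], [3], [1, 2], [3], [1, 2]]
```

This is presumably why training loops that consume such a generator are usually told how many steps make up one epoch.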
0

It looks to me like you're trying to achieve something along this line:

def batch_generator(image_paths, batch_size, isTraining):
    your_code_here

Calling the generator - instead of what you have:

index = next(batch_generator(train_dataset, 10, True))

You can try:

index = iter(batch_generator(train_dataset, 10, True))
index.__next__()
kerwei
  • 1
    1/ you don't need to call `iter()` on an iterable (in this case it will actually just return its argument unchanged), 2/ `__next__()` is a "magic method" (the implementation of a generic operator or operator-like function) and should not be called directly but through the `next()` function. – bruno desthuilliers Jan 18 '19 at 08:06
  • @brunodesthuilliers Thanks for the pointer! Admittedly, I'm still rather new to generators. Getting into these discussions helps me to learn and improve. – kerwei Jan 18 '19 at 08:08
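Both remarks in the comment above can be checked with a short sketch. The `numbers` function is a hypothetical toy generator, not part of the question's code:

```python
def numbers():
    """Hypothetical toy generator function."""
    yield 1
    yield 2

gen = numbers()

# A generator is already its own iterator: iter() hands it back as-is.
print(iter(gen) is gen)  # True

# next(gen) and gen.__next__() do the same thing; next() is the
# conventional spelling.
print(next(gen))       # 1
print(gen.__next__())  # 2

# next() also accepts a default, returned once the generator is exhausted.
print(next(gen, None))  # None
```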
0

I made my own generator that supports a limit, batches, or plain one-at-a-time iteration:

def gen(batch = None, limit = None):
    ret = []
    for i in range(1, 11): # replace with your own data reading; keep a counter i (i += 1) inside the loop
        if batch:
            ret.append(i)
            if limit and i == limit:
                if len(ret):            
                    yield ret
                return
            if len(ret) == batch:
                yield ret
                ret = []
        else:
            if limit and i > limit:
                break
            yield i
    if batch and len(ret): # yield the rest of the list
        yield ret
            
g = gen(batch=5, limit=8) # batches with limit
#g = gen(batch=5) # batches
#g = gen(limit=5) # step 1 with limit
#g = gen() # step 1, no limit
for i in g:
    print(i)
Tomerikoo
luky