0

I would like to use an IPython widget button (widgets.Button) to run a function that trains a deep learning model using Keras's fit_generator() and ImageDataGenerator(). I tried to use a lambda to pass the arguments to the function, but it raises TypeError: expected str, bytes or os.PathLike object, not Button.

Code:

def trainGenerator(batch_size,train_path,image_folder,mask_folder,aug_dict,image_color_mode = "grayscale",
                    mask_color_mode = "grayscale",image_save_prefix  = "image",mask_save_prefix  = "mask",
                    flag_multi_class = False,num_class = 2,save_to_dir = None,target_size = (256,256),seed = 1):
    """Yield paired (image, mask) batches for segmentation training.

    Two ImageDataGenerator instances are built from the same ``aug_dict``. Both
    read from ``train_path`` with the same ``seed``, one restricted to the
    ``image_folder`` sub-directory and the other to ``mask_folder``. The shared
    seed is presumably meant to keep the random augmentation of an image and
    its mask in step; confirm against the Keras docs. Each pair of batches is
    passed through ``adjustData`` before being yielded.

    Because this is a generator function, nothing (including the directory
    scan) runs until the first batch is requested.

    Args:
        batch_size: samples per batch, used for both streams.
        train_path: directory that contains ``image_folder`` and ``mask_folder``.
        image_folder: sub-directory of ``train_path`` holding the input images.
        mask_folder: sub-directory of ``train_path`` holding the masks.
        aug_dict: keyword arguments forwarded to both ImageDataGenerator instances.
        image_color_mode: ``color_mode`` for the image stream.
        mask_color_mode: ``color_mode`` for the mask stream.
        image_save_prefix: ``save_prefix`` for the image stream.
        mask_save_prefix: ``save_prefix`` for the mask stream.
        flag_multi_class: forwarded to ``adjustData``.
        num_class: forwarded to ``adjustData``.
        save_to_dir: forwarded to both ``flow_from_directory`` calls.
        target_size: forwarded to both ``flow_from_directory`` calls.
        seed: forwarded to both ``flow_from_directory`` calls.

    Yields:
        2-tuples made from the two values returned by ``adjustData``.
    """
    # One augmenter per stream, both configured from the same aug_dict.
    image_datagen, mask_datagen = (ImageDataGenerator(**aug_dict) for _ in range(2))

    # Keyword arguments that are identical for the image and mask streams.
    # class_mode=None makes each stream yield only the loaded arrays.
    common = dict(class_mode=None, target_size=target_size, batch_size=batch_size,
                  save_to_dir=save_to_dir, seed=seed)
    image_generator = image_datagen.flow_from_directory(
        train_path, classes=[image_folder], color_mode=image_color_mode,
        save_prefix=image_save_prefix, **common)
    mask_generator = mask_datagen.flow_from_directory(
        train_path, classes=[mask_folder], color_mode=mask_color_mode,
        save_prefix=mask_save_prefix, **common)

    for img_batch, mask_batch in zip(image_generator, mask_generator):
        adjusted_img, adjusted_mask = adjustData(img_batch, mask_batch, flag_multi_class, num_class)
        yield adjusted_img, adjusted_mask

def segmentation_training(trainfolder, modelname):
    """Train a freshly built U-Net on the data under *trainfolder*.

    Batches of 2 come from trainGenerator, reading the 'image' and 'label'
    sub-folders. The weights are checkpointed to Models/<modelname>.hdf5 whenever
    the training loss improves. Training runs for 1 epoch of 3 steps.

    Args:
        trainfolder: directory containing the 'image' and 'label' sub-folders.
        modelname: base name (without extension) of the checkpoint file.
    """
    # NOTE(review): [0.0, 0, 0.5] lists a zero shift twice. If ImageDataGenerator
    # samples uniformly from a list, 0 is picked twice as often as 0.5. Confirm
    # this is intended.
    augmentation = {
        'rotation_range': 0.1,
        'width_shift_range': [0.0, 0, 0.5],
        'height_shift_range': [0.0, 0, 0.5],
        'zoom_range': [0.5, 1],
        'horizontal_flip': True,
        'fill_mode': 'nearest',
    }
    batches = trainGenerator(2, trainfolder, 'image', 'label', augmentation, save_to_dir=None)

    net = unet()
    # NOTE(review): assumes a 'Models' directory already exists relative to the
    # working directory. TODO confirm.
    checkpoint_path = os.path.join('Models', modelname + '.hdf5')
    checkpoint = ModelCheckpoint(checkpoint_path, monitor='loss', verbose=1, save_best_only=True)
    net.fit_generator(batches, steps_per_epoch=3, epochs=1, callbacks=[checkpoint])

modelname = "test"
trainfolder = Path('Data/Segmentation/dataset/train')
btn = widgets.Button(description="Run")
# Button.on_click calls its handler with the clicked Button as a positional
# argument. Without a parameter to absorb it, that Button overwrites the
# `trainfolder` default and later reaches os.path.join. This is what raised
# "TypeError: expected str, bytes or os.PathLike object, not Button".
btn.on_click(lambda _button, trainfolder=trainfolder, modelname=modelname: segmentation_training(trainfolder, modelname))
display(btn)

Error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-41-d4282548b872> in <lambda>(trainfolder, modelname)
     46 trainfolder = Path('Data/Segmentation/dataset/train')
     47 btn = widgets.Button(description="Run")
---> 48 btn.on_click(lambda trainfolder=trainfolder, modelname=modelname : segmentation_training(trainfolder,modelname))
     49 display(btn)

<ipython-input-41-d4282548b872> in segmentation_training(trainfolder, modelname)
     40     model = unet()
     41     model_checkpoint = ModelCheckpoint(os.path.join('Models',modelname+'.hdf5'), monitor='loss',verbose=1, save_best_only=True)
---> 42     model.fit_generator(myGene,steps_per_epoch=3,epochs=1,callbacks=[model_checkpoint])
     43 
     44 

~/virtualenv/lib/python3.6/site-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
     89                 warnings.warn('Update your `' + object_name +
     90                               '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
     92         wrapper._original_function = func
     93         return wrapper

~/virtualenv/lib/python3.6/site-packages/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   1413             use_multiprocessing=use_multiprocessing,
   1414             shuffle=shuffle,
-> 1415             initial_epoch=initial_epoch)
   1416 
   1417     @interfaces.legacy_generator_methods_support

~/virtualenv/lib/python3.6/site-packages/keras/engine/training_generator.py in fit_generator(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
    175             batch_index = 0
    176             while steps_done < steps_per_epoch:
--> 177                 generator_output = next(output_generator)
    178 
    179                 if not hasattr(generator_output, '__len__'):

~/virtualenv/lib/python3.6/site-packages/keras/utils/data_utils.py in get(self)
    791             success, value = self.queue.get()
    792             if not success:
--> 793                 six.reraise(value.__class__, value, value.__traceback__)

~/virtualenv/lib/python3.6/site-packages/six.py in reraise(tp, value, tb)
    691             if value.__traceback__ is not tb:
    692                 raise value.with_traceback(tb)
--> 693             raise value
    694         finally:
    695             value = None

~/virtualenv/lib/python3.6/site-packages/keras/utils/data_utils.py in _data_generator_task(self)
    656                             # => Serialize calls to
    657                             # infinite iterator/generator's next() function
--> 658                             generator_output = next(self._generator)
    659                             self.queue.put((True, generator_output))
    660                         else:

<ipython-input-41-d4282548b872> in trainGenerator(batch_size, train_path, image_folder, mask_folder, aug_dict, image_color_mode, mask_color_mode, image_save_prefix, mask_save_prefix, flag_multi_class, num_class, save_to_dir, target_size, seed)
     13         save_to_dir = save_to_dir,
     14         save_prefix  = image_save_prefix,
---> 15         seed = seed)
     16     mask_generator = mask_datagen.flow_from_directory(
     17         train_path,

~/virtualenv/lib/python3.6/site-packages/keras_preprocessing/image.py in flow_from_directory(self, directory, target_size, color_mode, classes, class_mode, batch_size, shuffle, seed, save_to_dir, save_prefix, save_format, follow_links, subset, interpolation)
    962             follow_links=follow_links,
    963             subset=subset,
--> 964             interpolation=interpolation)
    965 
    966     def standardize(self, x):

~/virtualenv/lib/python3.6/site-packages/keras_preprocessing/image.py in __init__(self, directory, image_data_generator, target_size, color_mode, classes, class_mode, batch_size, shuffle, seed, data_format, save_to_dir, save_prefix, save_format, follow_links, subset, interpolation)
   1731         self.samples = sum(pool.map(function_partial,
   1732                                     (os.path.join(directory, subdir)
-> 1733                                      for subdir in classes)))
   1734 
   1735         print('Found %d images belonging to %d classes.' %

/usr/lib/python3.6/multiprocessing/pool.py in map(self, func, iterable, chunksize)
    264         in a list that is returned.
    265         '''
--> 266         return self._map_async(func, iterable, mapstar, chunksize).get()
    267 
    268     def starmap(self, func, iterable, chunksize=None):

/usr/lib/python3.6/multiprocessing/pool.py in _map_async(self, func, iterable, mapper, chunksize, callback, error_callback)
    374             raise ValueError("Pool not running")
    375         if not hasattr(iterable, '__len__'):
--> 376             iterable = list(iterable)
    377 
    378         if chunksize is None:

~/virtualenv/lib/python3.6/site-packages/keras_preprocessing/image.py in <genexpr>(.0)
   1731         self.samples = sum(pool.map(function_partial,
   1732                                     (os.path.join(directory, subdir)
-> 1733                                      for subdir in classes)))
   1734 
   1735         print('Found %d images belonging to %d classes.' %

/usr/lib/python3.6/posixpath.py in join(a, *p)
     78     will be discarded.  An empty last part will result in a path that
     79     ends with a separator."""
---> 80     a = os.fspath(a)
     81     sep = _get_sep(a)
     82     path = a

TypeError: expected str, bytes or os.PathLike object, not Button

When I call segmentation_training(trainfolder, modelname) directly, without the button, it works fine. How can I call the function by pressing the button? Thanks in advance.

1 Answer

0

Button.on_click calls your handler with the clicked Button instance as its positional argument. Your lambda's first parameter is trainfolder, so that argument replaced the default value. Inside the lambda, trainfolder was the btn Button object rather than a path. The functions that later tried to use it as a file path (ultimately os.path.join inside flow_from_directory) could not handle a Button and raised the error.

If you want to keep the lambda, add a first parameter to receive the clicked button (named self below) and simply ignore it:

# `self` is only the name this lambda gives to the clicked Button that on_click
# passes in. It is ignored; the default values carry the real arguments.
btn.on_click(
    lambda self, trainfolder=trainfolder, modelname=modelname:
        segmentation_training(trainfolder, modelname)
)

Alternatively, you can use functools.partial to pre-bind the arguments to a named function:

import functools 

def click_func(trainfolder, modelname, button=None):
    """Start segmentation training when the Run button is clicked.

    Args:
        trainfolder: training-data directory, forwarded to segmentation_training.
        modelname: checkpoint base name, forwarded to segmentation_training.
        button: the clicked Button, which Button.on_click passes to its handler
            as a positional argument. It is accepted here and otherwise ignored.
    """
    segmentation_training(trainfolder, modelname)

# Bind the two arguments positionally so the Button supplied by on_click lands
# in the trailing `button` parameter. Binding them as keywords instead
# (trainfolder=..., modelname=...) sends the Button into the first positional
# slot, and the call fails with "got multiple values for argument 'trainfolder'".
btn.on_click(functools.partial(click_func, trainfolder, modelname))
Jon
  • 1,820
  • 2
  • 19
  • 43
  • @MaurícioPires, glad to hear it! If you think the answer is correct and helpful you are welcome to mark it as accepted. – Jon Feb 07 '20 at 15:44