I'm trying to implement data augmentation that depends on the kind of input I'm receiving; specifically, I want to use the filename to decide which augmentations are applied.
For this I'm using TensorFlow's Object Detection API, and I'm modifying its augment_input_data function. I've added the following code to the beginning of the function:
# Filter a copy of the augmentation options based on the input's filename.
# NOTE(review): indentation was lost in this listing; the `for`/`if`/`remove`
# lines after each `def` are that function's body.
# Copy the list so the removals below act on the copy rather than on
# `data_augmentation_options` itself.
my_data_augmentation_options = data_augmentation_options.copy()
# Removes the vertical/horizontal flip entries from `data_aug_options` in place.
# NOTE(review): this mutates a Python list argument inside a `tf.function`
# (and removes from the list while iterating over it) -- the mutation of the
# argument is what the reported ValueError complains about.
@tf.function
def filter_supervised(data_aug_options):
for opt in data_aug_options:
# Entries are matched on `opt[0].__name__`; presumably opt[0] is the
# augmentation function -- confirm against the caller.
if opt[0].__name__ in ['random_vertical_flip', 'random_horizontal_flip']:
data_aug_options.remove(opt)
# Removes the `random_distort_color` entry from `data_aug_options` in place.
# NOTE(review): same in-place mutation of the argument as filter_supervised.
@tf.function
def filter_unsupervised(data_aug_options):
for opt in data_aug_options:
if opt[0].__name__ in ['random_distort_color']:
data_aug_options.remove(opt)
# NOTE(review): both filter functions are *called* here, so they run before
# tf.cond is entered and their return values are what gets passed as the
# branches; tf.cond expects callables for its true/false branches -- confirm.
# NOTE(review): '\Scrop\S' is a non-raw string literal; a raw string
# (r'...') is the usual form for regex patterns.
tf.cond(tf.strings.regex_full_match(tensor_dict['filename'], '\Scrop\S'), filter_unsupervised(my_data_augmentation_options), filter_supervised(my_data_augmentation_options))
Unfortunately this raises the following error:
ValueError: filter_unsupervised() should not modify its Python input arguments. Check if it modifies any lists or dicts passed as arguments. Modifying a copy is allowed.
Is there a way to modify a Python object inside a tf.function? Or do I need to find the point where execution stops being eager and keep eager execution up to this part?
Thanks in advance.
A minimal example that reproduces the error I'm facing is:
# Minimal reproduction: filter a copy of the augmentation options depending on
# whether tensor_dict['filename'] matches a "crop" pattern.
#   tensor_dict: mapping that is indexed with the key 'filename'.
#   data_augmentation_options: list of entries whose first element has a
#       `__name__` (presumably (function, kwargs) tuples -- confirm).
# NOTE(review): indentation was lost in this listing; everything down to the
# "Continue doing stuff below" comment belongs to this function's body.
def custom_augment_input_data(tensor_dict, data_augmentation_options):
# Copy the list so removals act on the copy, not on the argument itself.
my_data_augmentation_options = data_augmentation_options.copy()
# Removes the `random_horizontal_flip` entry from `data_aug_options` in place.
# NOTE(review): mutating a Python argument inside a `tf.function` is what the
# reported ValueError is about; the list is also modified while iterated.
@tf.function
def filter_supervised(data_aug_options):
for opt in data_aug_options:
if opt[0].__name__ in ['random_horizontal_flip']:
data_aug_options.remove(opt)
# Removes the `random_distort_color` entry from `data_aug_options` in place.
# NOTE(review): same in-place mutation of the argument as filter_supervised.
@tf.function
def filter_unsupervised(data_aug_options):
for opt in data_aug_options:
if opt[0].__name__ in ['random_distort_color']:
data_aug_options.remove(opt)
# NOTE(review): the two filter functions are called while building the
# argument list, i.e. before tf.cond runs, and their return values are passed
# as the branches; tf.cond expects callables -- confirm.
# NOTE(review): as a *full* match, '\Scrop\S' would only accept one
# non-whitespace character, "crop", and one more non-whitespace character;
# a pattern such as r'.*crop.*' may be what is intended -- confirm.
tf.cond(tf.strings.regex_full_match(tensor_dict['filename'], '\Scrop\S'),
filter_unsupervised(my_data_augmentation_options), filter_supervised(my_data_augmentation_options))
# Continue doing stuff below
from object_detection.core.preprocessor import random_horizontal_flip, random_distort_color
# Driver for the minimal reproduction.
# Each entry pairs a preprocessor function with the keyword arguments
# configured for it.
data_augmentation_options = [
    (
        random_horizontal_flip,
        {'keypoint_flip_permutation': None, 'probability': 0.5},
    ),
    (random_distort_color, {}),
]

# Input dict holding only the filename entry.
# NOTE(review): `fields` is not imported in the visible listing -- presumably
# object_detection.core.standard_fields; confirm.
dummy_dict = {
    fields.InputDataFields.filename: './crop_32412.jpg',
}

custom_augment_input_data(dummy_dict, data_augmentation_options)