
I'm trying to implement data augmentation that depends on the kind of input I'm receiving. Specifically, I want to use the filename to decide which augmentations are applied.

For this I'm using TensorFlow's Object Detection API, and I'm modifying the augment_input_data function. I've added the following code to the beginning of the function:

my_data_augmentation_options = data_augmentation_options.copy()
  
@tf.function
def filter_supervised(data_aug_options):
  for opt in data_aug_options:
    if opt[0].__name__ in ['random_vertical_flip', 'random_horizontal_flip']:
      data_aug_options.remove(opt)

@tf.function
def filter_unsupervised(data_aug_options):
  for opt in data_aug_options:
    if opt[0].__name__ in ['random_distort_color']:
      data_aug_options.remove(opt)

tf.cond(tf.strings.regex_full_match(tensor_dict['filename'], '\Scrop\S'), filter_unsupervised(my_data_augmentation_options), filter_supervised(my_data_augmentation_options))

Unfortunately this raises the following error:

ValueError: filter_unsupervised() should not modify its Python input arguments. Check if it modifies any lists or dicts passed as arguments. Modifying a copy is allowed.

Is there a way to modify a Python object inside a tf.function? Or do I need to find where execution stops being eager and try to keep it eager up to this point?

Thanks in advance.

A minimal example that reproduces the error:

def custom_augment_input_data(tensor_dict, data_augmentation_options):
  my_data_augmentation_options = data_augmentation_options.copy()
  @tf.function
  def filter_supervised(data_aug_options):
    for opt in data_aug_options:
      if opt[0].__name__ in ['random_horizontal_flip']:
        data_aug_options.remove(opt)

  @tf.function
  def filter_unsupervised(data_aug_options):
    for opt in data_aug_options:
      if opt[0].__name__ in ['random_distort_color']:
        data_aug_options.remove(opt)
  tf.cond(tf.strings.regex_full_match(tensor_dict['filename'], '\Scrop\S'),
          filter_unsupervised(my_data_augmentation_options), filter_supervised(my_data_augmentation_options))
  # Continue doing stuff below

import tensorflow as tf
from object_detection.core import standard_fields as fields
from object_detection.core.preprocessor import random_horizontal_flip, random_distort_color
data_augmentation_options = [(random_horizontal_flip, {'keypoint_flip_permutation': None, 'probability':0.5}),
                             (random_distort_color, {})]

dummy_dict = {fields.InputDataFields.filename: './crop_32412.jpg'}
custom_augment_input_data(dummy_dict, data_augmentation_options)
  

1 Answer


The error message is actually clear: make a copy of the object, and then you can modify the copy however you like.

In my case the object I want to modify is a dict. I first tried to make a copy using copy.deepcopy, but while executing in a graph the values turn out to be non-serializable tensors, which can't be copied this way. (The OP makes a shallow copy with data_augmentation_options.copy(), but that copy is what gets passed to the tf.function, so the function still seems to be modifying its own input argument.)
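To make the difference between the copy styles concrete, here is a small plain-Python sketch (no TensorFlow needed). The `options` list is a hypothetical stand-in for the real augmentation options, with strings in place of the preprocessor functions. It suggests that `list.copy()` and a comprehension both create a new outer container whose elements are shared with the original, while `copy.deepcopy` also duplicates the elements.

```python
import copy

# Hypothetical stand-in for the augmentation options: (name, kwargs) pairs.
options = [("random_horizontal_flip", {"probability": 0.5}),
           ("random_distort_color", {})]

shallow = options.copy()          # new list, same element objects
manual = [e for e in options]     # also a new list, same element objects
deep = copy.deepcopy(options)     # new list and new element objects

for label, candidate in [("copy()", shallow),
                         ("comprehension", manual),
                         ("deepcopy", deep)]:
    new_container = candidate is not options
    shared_kwargs = candidate[0][1] is options[0][1]
    print(label, "new container:", new_container, "shared kwargs:", shared_kwargs)

# Removing an entry from a copy leaves the original list untouched.
manual.pop()
print(len(options), len(manual))
```

Under these assumptions, the manual copy is itself a shallow copy, so what matters for the error is which list object gets mutated, not how deep the copy is.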

So my solution is to make a manual copy inside the function and modify that instead: dct = {k: v for k, v in dct.items()}, which bypasses the issue. I expect the same approach to work for other types of objects (e.g. lst = [e for e in lst]).
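Applied to the question's filters, the same idea can be sketched as building a new list rather than calling remove() on the argument. This is plain Python with hypothetical stand-in functions in place of the real preprocessor ops, and `filter_options` is an illustrative helper, not part of the Object Detection API. It also avoids removing items from a list while iterating over it, which can skip elements.

```python
def random_horizontal_flip(image):
    """Hypothetical stand-in for the preprocessor op of the same name."""
    return image


def random_distort_color(image):
    """Hypothetical stand-in for the preprocessor op of the same name."""
    return image


def filter_options(options, excluded_names):
    """Return a new list without the options whose function name is excluded.

    The input list is never modified.
    """
    return [opt for opt in options if opt[0].__name__ not in excluded_names]


data_augmentation_options = [
    (random_horizontal_flip, {"probability": 0.5}),
    (random_distort_color, {}),
]

supervised = filter_options(data_augmentation_options, {"random_horizontal_flip"})
unsupervised = filter_options(data_augmentation_options, {"random_distort_color"})

print([opt[0].__name__ for opt in supervised])
print([opt[0].__name__ for opt in unsupervised])
print(len(data_augmentation_options))
```

Whether this is sufficient inside a graph depends on the rest of the pipeline; the sketch only shows the non-mutating form of the filter.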

mk6