My problem is this: during pre-processing, I want to apply a function, randomly selected from a set of functions, to dataset examples using the tf.data.Dataset and tf.function APIs.
Specifically, my data are 3D volumes and I wish to apply a rotation chosen from a set of 24 predefined rotation functions. I would like to write this code within a tf.function, which limits the use of packages like numpy and of plain Python list indexing.
For example, I would like to do something like this:
import tensorflow as tf

@tf.function
def func1(tensor):
    # Apply some rotation here
    ...

@tf.function
def func2(tensor):
    ...

...

@tf.function
def func24(tensor):
    ...

@tf.function
def apply(tensor):
    list_of_funcs = [func1, func2, ..., func24]
    # Randomly sample an index from 0-23 (maxval is exclusive for integer dtypes)
    a = tf.random.uniform([1], minval=0, maxval=24, dtype=tf.int32)
    # Fails: a Python list cannot be indexed with a Tensor
    return list_of_funcs[a](tensor)
However, I cannot index list_of_funcs with a tensor; doing so raises TypeError: list indices must be integers or slices, not Tensor. Additionally, as far as I know, I cannot collect these functions into a tf.Tensor and use tf.gather.
So my question is: how can I reasonably and neatly sample from these functions in a tf.function?
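To make the goal concrete, here is a minimal sketch of the kind of dispatch I am after, assuming tf.switch_case is an acceptable way to pick a branch from a scalar index tensor. The names identity, rot90_xy, rot180_xy, ROTATIONS and apply_random_rotation are hypothetical stand-ins (three simple rotations instead of my real 24), and the sketch assumes cubic volumes so that every branch returns the same shape.

```python
import tensorflow as tf

# Hypothetical stand-ins for the 24 rotation functions (only 3 shown).
def identity(volume):
    return volume

def rot90_xy(volume):
    # 90-degree rotation in the plane of axes 0 and 1:
    # swap the two axes, then flip along axis 0.
    return tf.reverse(tf.transpose(volume, perm=[1, 0, 2]), axis=[0])

def rot180_xy(volume):
    # 180-degree rotation in the same plane: flip both axes.
    return tf.reverse(volume, axis=[0, 1])

ROTATIONS = [identity, rot90_xy, rot180_xy]

@tf.function
def apply_random_rotation(volume):
    # Scalar index in [0, len(ROTATIONS)); maxval is exclusive.
    idx = tf.random.uniform([], minval=0, maxval=len(ROTATIONS), dtype=tf.int32)
    # Branch callables take no arguments, so bind each function and the
    # volume in a lambda (f=f avoids Python's late-binding closures).
    branches = [lambda f=f: f(volume) for f in ROTATIONS]
    return tf.switch_case(idx, branch_fns=branches)

# Usage inside a tf.data pipeline: each element gets its own random draw.
volume = tf.reshape(tf.range(27, dtype=tf.float32), [3, 3, 3])
dataset = tf.data.Dataset.from_tensors(volume).repeat(4).map(apply_random_rotation)
rotated = list(dataset.as_numpy_iterator())
```

The index is drawn with shape [] rather than [1] because the branch index is expected to be a scalar. Whether this scales neatly to 24 branches, or whether there is a tidier pattern, is exactly what I am unsure about.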