
My problem is this: during pre-processing I want to apply a function, randomly selected from a set of functions, to dataset examples using the tf.data.Dataset and tf.function APIs.

Specifically, my data are 3D volumes and I wish to apply one rotation from a set of 24 predefined rotation functions. I would like to write this code within a tf.function, which limits the use of packages like NumPy and of Python list indexing.

For example, I would like to do something like this:

import tensorflow as tf

@tf.function
def func1(tensor):
    # Apply some rotation here
    ...

@tf.function
def func2(tensor):
    ...

...

@tf.function
def func24(tensor):
    ...


@tf.function
def apply(tensor):
    list_of_funcs = [func1, func2, ..., func24]

    # Randomly sample a scalar index from 0-23 (maxval is exclusive)
    a = tf.random.uniform([], minval=0, maxval=24, dtype=tf.int32)
    
    return list_of_funcs[a](tensor)

However, I cannot index list_of_funcs with a tensor, because that raises TypeError: list indices must be integers or slices, not Tensor. Additionally, as far as I know, I cannot collect these functions into a tf.Tensor and use tf.gather.

So my question: how can I reasonably and neatly sample from these functions in a tf.function?

Matt Lyon
  • I would rather consider whether there is a better way than to define 24 separate functions for rotation... are they really so different that you cannot have one function with different parameterizations? – xdurch0 Oct 25 '21 at 16:18
  • That's a good point: you certainly could define one function for all 24. However, I fear that function would essentially end up being a bunch of if statements in this case. – Matt Lyon Oct 26 '21 at 10:07
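A minimal sketch of xdurch0's suggestion, assuming cubic [D, D, D] volumes: every rotation of a cube maps axes onto axes, so it can be written as an axis permutation (tf.transpose) followed by flips of some axes (tf.reverse). Of the 48 such combinations, the 24 with determinant +1 are rotations and the other 24 are reflections, so one parameterized function covers all cases without a chain of if statements. The names permutation_sign, make_cube_rotations, rotate and random_rotate are hypothetical, and tf.switch_case is used here to pick one rotation per call.

```python
import itertools

import tensorflow as tf


def permutation_sign(perm):
    # +1 for an even permutation, -1 for an odd one (parity of the inversion count).
    inversions = sum(
        1
        for i in range(len(perm))
        for j in range(i + 1, len(perm))
        if perm[i] > perm[j]
    )
    return 1 if inversions % 2 == 0 else -1


def make_cube_rotations():
    # 6 axis permutations x 8 flip patterns = 48 signed permutations.
    # Keeping those with determinant +1 leaves the 24 proper rotations.
    rotations = []
    for perm in itertools.permutations(range(3)):
        for flips in itertools.product([False, True], repeat=3):
            if permutation_sign(perm) * (-1) ** sum(flips) == 1:
                flip_axes = [axis for axis, flip in enumerate(flips) if flip]
                rotations.append((list(perm), flip_axes))
    return rotations


ROTATIONS = make_cube_rotations()


def rotate(volume, perm, flip_axes):
    # Assumes a cubic [D, D, D] volume, so every rotation keeps the same shape.
    rotated = tf.transpose(volume, perm=perm)
    if flip_axes:
        rotated = tf.reverse(rotated, axis=flip_axes)
    return rotated


@tf.function
def random_rotate(volume):
    index = tf.random.uniform([], minval=0, maxval=len(ROTATIONS), dtype=tf.int32)
    # Default arguments bind perm/flips now; a plain closure would see only the last pair.
    branches = [
        lambda perm=perm, flips=flips: rotate(volume, perm, flips)
        for perm, flips in ROTATIONS
    ]
    return tf.switch_case(index, branch_fns=branches)
```

Because each rotation is just a (perm, flip_axes) pair, random_rotate can be passed to tf.data.Dataset.map as is. Non-cubic volumes would need extra handling, since the branches of tf.switch_case should agree on their output shape.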

3 Answers


You can use nested tf.cond calls. tf.cond(pred, true_fn, false_fn) calls true_fn if pred is true and false_fn otherwise, so with more than two functions you can nest another tf.cond inside false_fn, as deep as you need. For instance, the functions below multiply x by 2, 3, 4, or 5, depending on the value of a random variable.

import tensorflow as tf

x = 10

@tf.function
def mult_2():
    tf.print(f'i was 2, returning {x} multiplied by 2')
    return tf.multiply(x, 2)

@tf.function
def mult_3():
    tf.print(f'i was 3, returning {x} multiplied by 3')
    return tf.multiply(x, 3)


@tf.function
def mult_4():
    tf.print(f'i was 4, returning {x} multiplied by 4')
    return tf.multiply(x, 4)


@tf.function
def mult_5():
    tf.print(f'i was 5, returning {x} multiplied by 5')
    return tf.multiply(x, 5)


# maxval is exclusive, so i is one of 2, 3, 4 or 5
i = tf.random.uniform((), 2, 6, dtype=tf.int32)

tf.cond(i == 2, mult_2,
        lambda: tf.cond(i == 3, mult_3,
                        lambda: tf.cond(i == 4, mult_4, mult_5)))
i was 3, returning 10 multiplied by 3
<tf.Tensor: shape=(), dtype=int32, numpy=30>

Note that mult_5 will execute if none of the conditions are met.
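Writing the nesting by hand gets unwieldy for 24 functions. A sketch, assuming the functions take the tensor as their only argument (unlike the zero-argument mult_* functions above): select_with_nested_cond is a hypothetical helper that builds the same chain of tf.cond calls recursively, with the last function acting as the fallback, just like mult_5.

```python
import tensorflow as tf


def select_with_nested_cond(index, fns, tensor):
    # Builds tf.cond(index == 0, f0, lambda: tf.cond(index == 1, f1, ...))
    # recursively; the last function runs when no earlier index matches.
    def branch(k):
        if k == len(fns) - 1:
            return lambda: fns[k](tensor)
        return lambda: tf.cond(
            tf.equal(index, k), lambda: fns[k](tensor), branch(k + 1)
        )

    return branch(0)()


# Default arguments bind each multiplier now (plain closures would all see m == 5).
fns = [lambda t, m=m: t * m for m in (2, 3, 4, 5)]


@tf.function
def apply(tensor):
    i = tf.random.uniform([], minval=0, maxval=len(fns), dtype=tf.int32)
    return select_with_nested_cond(i, fns, tensor)
```

Calling apply(tf.constant(10)) returns 20, 30, 40 or 50 depending on the sampled index.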

Innat
Nicolas Gervais
  • Is this effectively the same as using `if` statements within the main `tf.function` `apply`? As I understand it, Python control flow code gets converted to `tf.cond` and `tf.while_loop` within a `tf.function`. – Matt Lyon Oct 26 '21 at 10:00
  • What do you mean by "the same"? I think most of your questions are covered in the documentation I linked. – Nicolas Gervais Oct 26 '21 at 13:17
  • By "the same" I mean that Python code using `if` and `else` within a function decorated with `tf.function` would get converted into TensorFlow control flow ops such as `tf.cond` when executed in the graph. – Matt Lyon Oct 27 '21 at 13:25
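On that question: as far as I know, yes. Inside a tf.function, AutoGraph converts a Python if/elif/else whose condition is a tensor into tf.cond ops, so the chain below should trace to the same kind of graph as the nested tf.cond above. A minimal sketch; the multipliers simply mirror the answer's example.

```python
import tensorflow as tf


@tf.function
def apply(x):
    i = tf.random.uniform([], minval=2, maxval=6, dtype=tf.int32)  # 2, 3, 4 or 5
    # `i == 2` is a tensor, so AutoGraph turns this if/elif/else chain into
    # nested tf.cond ops instead of evaluating it as Python control flow.
    if i == 2:
        y = x * 2
    elif i == 3:
        y = x * 3
    elif i == 4:
        y = x * 4
    else:
        y = x * 5
    return y
```

Every branch must assign y with the same dtype, because the converted tf.cond needs matching outputs from all branches.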

Maybe try using tf.py_function, which:

Wraps a python function into a TensorFlow op that executes it eagerly.

For example (tested on Google Colab):

import tensorflow as tf
import random

@tf.function
def func1(tensor):
    # tf.print runs on every call; a plain print would only fire while tracing
    tf.print('func1')
    return tensor

@tf.function
def func2(tensor):
    tf.print('func2')
    return tensor

@tf.function
def func3(tensor):
    tf.print('func3')
    return tensor

@tf.function
def func4(tensor):
    tf.print('func4')
    return tensor

@tf.function
def apply(tensor):
    dispatcher = {
        'func1': func1,
        'func2': func2,
        'func3': func3,
        'func4': func4
    }
    keys = list(dispatcher)

    def get_random_function_and_apply(t):
        # Runs eagerly inside tf.py_function, so random.choice is re-drawn on every call
        return dispatcher[random.choice(keys)](t)

    y = tf.py_function(func=get_random_function_and_apply, inp=[tensor], Tout=tf.float32)
    # tf.py_function drops static shape information; restore it from the input
    y.set_shape(tensor.shape)
    return y

mirrored_strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
with mirrored_strategy.scope():
    output = apply(tf.random.normal((5, 5, 5)))
    print(output)

'''
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')
func4
tf.Tensor(
[[[ 0.6041213  -2.054427    1.1755397  -0.62914884 -0.00978021]
  [ 0.06134182 -1.5529596  -0.3429052  -0.03199977 -1.1796658 ]
  [-0.65084136 -1.5009187  -0.43266404 -0.18494445  1.2958355 ]
  [-1.6614605  -0.7398612   1.5384725  -0.24926051 -0.5075399 ]
  [ 0.7781286  -0.4102168   1.2152135   0.4508075  -1.7295381 ]]

 [[-1.0509509  -1.271087    1.9061071   0.61855525  0.58581835]
  [ 2.080663    0.43406835  0.32372198 -0.71427256  0.04448809]
  [-0.6438594  -1.1245041  -0.4723388  -0.8302859  -2.0056007 ]
  [ 1.1778332   0.2977344   0.7516829   1.1387901  -0.71768486]
  [-0.44642782 -0.6523012  -0.48157197 -0.8197472   0.3635474 ]]

 [[-0.43357274  1.166849   -0.04528571  0.44322303  0.74193203]
  [ 1.2332342   0.07857647  1.3399298   0.62153     1.835202  ]
  [ 0.48021084  0.36239776  0.16630112  0.59010863  1.8134127 ]
  [-1.1444335   1.2445287  -1.2320557   0.08095992 -0.1379302 ]
  [-1.101756   -1.8099649   0.18504284  0.15212883  0.33380997]]

 [[-0.68228734 -0.82357454 -0.744171   -0.04959428 -1.3200126 ]
  [ 0.813062    1.0669035  -0.7924809  -0.0548021   0.8043163 ]
  [ 1.6480085  -0.17134379  0.25517386  0.02731211  1.2226027 ]
  [-1.9785942  -0.22399756 -0.6814836   1.2065881  -1.7922156 ]
  [-0.34833568 -1.0567352   1.5795225   0.14899854  0.5924402 ]]

 [[-1.057639   -1.1659449  -0.22045298  0.39324322 -1.3500952 ]
  [-0.32044935  0.9534627   0.40809664 -1.0296333  -0.8129102 ]
  [-0.13515176 -0.32676768 -0.9333701   0.35130095 -1.5411847 ]
  [ 2.090785    0.3497966   0.27694222  0.78199005 -0.08591356]
  [ 0.9621986  -2.3930101  -1.1035724   0.27208164 -1.1846163 ]]], shape=(5, 5, 5), dtype=float32)

'''
AloneTogether
  • Yes, I could go down that route but I wanted to avoid using `tf.py_function` because of its limitations, particularly in distributed training. – Matt Lyon Oct 25 '21 at 15:57
  • Hmm, what limitations are you exactly talking about? – AloneTogether Oct 25 '21 at 16:20
  • https://www.tensorflow.org/api_docs/python/tf/py_function says at the bottom with regards to distributed training "The operation must run in the same address space as the Python program that calls tf.py_function(). If you are using distributed TensorFlow, you must run a tf.distribute.Server in the same process as the program that calls tf.py_function() and you must pin the created operation to a device in that server". It's not clear to me how to ensure the distributed training would work as intended from this – Matt Lyon Oct 26 '21 at 10:13

You can use tf.switch_case, for example:

import tensorflow as tf

def func1(tensor):
    return tensor * 1

def func2(tensor):
    return tensor * 2

def func24(tensor):
    return tensor * 24

class Lambda:
    def __init__(self, func, arg):
        self._func = func
        self._arg = arg
        
    def __call__(self):
        return self._func(self._arg)

@tf.function
def apply(tensor):
    list_of_funcs = [func1, func2, func24]

    branch_index = tf.random.uniform(shape=[], minval=0, maxval=len(list_of_funcs), dtype=tf.int32)
    output = tf.switch_case(
        branch_index=branch_index, 
        branch_fns=[Lambda(func, tensor) for func in list_of_funcs], 
    )
    
    return output

The @tf.function decorator is needed only on the outermost function you want to compile, which is apply in this case. If you use apply inside tf.data.Dataset.map, the decorator is not needed at all, because Dataset.map already traces the mapped function into a graph.

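See this discussion to understand why we have to define class Lambda here. In short, a plain lambda: func(tensor) inside the list comprehension would look up func only when the branch is called, after the loop has finished, so every branch would end up calling the last function in the list.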
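A sketch of the same apply without the Lambda class, assuming late binding is the only reason for it: binding func as a default argument of the lambda has the same effect. It also shows apply passed to tf.data.Dataset.map without @tf.function; the dataset of ones is made up purely for illustration.

```python
import tensorflow as tf


def func1(tensor):
    return tensor * 1


def func2(tensor):
    return tensor * 2


def func24(tensor):
    return tensor * 24


def apply(tensor):
    list_of_funcs = [func1, func2, func24]
    branch_index = tf.random.uniform(
        shape=[], minval=0, maxval=len(list_of_funcs), dtype=tf.int32
    )
    # `func=func` binds the current function as a default argument, which is the
    # job the Lambda class does; tf.switch_case calls each branch with no arguments.
    return tf.switch_case(
        branch_index=branch_index,
        branch_fns=[lambda func=func: func(tensor) for func in list_of_funcs],
    )


# No @tf.function here: Dataset.map traces `apply` into a graph on its own.
volumes = tf.data.Dataset.from_tensor_slices(tf.ones([6, 2, 2, 2]))
augmented = volumes.map(apply)
```

Each element of augmented is a 2x2x2 volume filled with 1, 2 or 24, depending on which branch was drawn for it.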

Alexey Tochin