Is there a way to iterate through a Python list inside control_flow_ops.while_loop?
I am using control_flow_ops.while_loop to iterate through a sequence in my RNN. I also have a list L of functions that I want to call.
Let j be the current index into the sequence inside the while_loop. At each step, I want to call the function at index j in L.
But L[j] doesn't work: j is a tensor holding the sequence index, while L is a plain Python list. tf.gather doesn't work either, because L can't be converted to a tensor (its elements are functions).
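To illustrate the tf.gather problem, here is a minimal sketch, assuming TensorFlow 2.x in eager mode. It tries tf.convert_to_tensor on a list of functions; fn1 and can_convert are hypothetical helpers, not part of the original code. The conversion fails, so there is no tensor for tf.gather to index into, while a list of plain numbers converts without trouble.

```python
import tensorflow as tf


def fn1(x):  # hypothetical stand-in for one element of L
    return x + 1.0


def can_convert(values):
    """Return True if tf.convert_to_tensor accepts `values`, else False."""
    try:
        tf.convert_to_tensor(values)
        return True
    except (TypeError, ValueError):
        # TensorFlow rejects objects it has no dtype for, such as functions.
        return False


print(can_convert([fn1]))      # False: a function has no tensor dtype
print(can_convert([0, 1, 2]))  # True: plain integers convert fine
```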
How do I access the elements of L in the while_loop? Example below:
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops

L = [fn1, fn2, fn3]

def body(j, seq):
    seq_elem = seq.read(j)  # seq is a TensorArray holding the sequence
    # Try to get the function stored in L.
    fn = L[j]  # Doesn't work: L is a Python list while j is a tensor.
    fn(seq_elem)
    return j + 1, seq

_, _ = control_flow_ops.while_loop(
    cond=lambda j, _: j < seq_length,
    body=body,
    loop_vars=(tf.constant(0, dtype=tf.int32), seq))
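One possible approach, sketched under assumptions: in TensorFlow 2.x, tf.switch_case takes an int32 scalar tensor plus a list of Python callables, builds one graph branch per callable, and picks the branch at run time. That avoids indexing the Python list with a tensor. The sketch uses tf.while_loop (the public counterpart of control_flow_ops.while_loop) inside tf.function so that j is symbolic, as in the question. fn1, fn2, fn3 and run are hypothetical stand-ins, and the input is a plain tensor rather than the question's TensorArray. In TF 1.x graph code, tf.case with (predicate, callable) pairs such as (tf.equal(j, i), ...) is a comparable option.

```python
import tensorflow as tf


# Hypothetical stand-ins for fn1, fn2, fn3 from the question.
def fn1(x):
    return x + 1.0


def fn2(x):
    return x * 2.0


def fn3(x):
    return x - 3.0


L = [fn1, fn2, fn3]


@tf.function  # graph mode, so j is a symbolic tensor as in the question
def run(seq_tensor):
    seq_length = tf.shape(seq_tensor)[0]
    out = tf.TensorArray(tf.float32, size=seq_length)

    def body(j, out):
        seq_elem = seq_tensor[j]
        # One zero-argument branch per Python function; f=f binds each
        # function now, avoiding Python's late binding of loop variables.
        branches = [lambda f=f: f(seq_elem) for f in L]
        # tf.switch_case runs the branch whose position equals j.
        result = tf.switch_case(j, branch_fns=branches)
        return j + 1, out.write(j, result)

    _, out = tf.while_loop(
        cond=lambda j, _: j < seq_length,
        body=body,
        loop_vars=(tf.constant(0, dtype=tf.int32), out))
    return out.stack()


print(run(tf.constant([10.0, 10.0, 10.0])).numpy())  # values 11, 20, 7
```

Because every branch is traced into the graph, all functions in L should return tensors with the same structure and dtype.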