
Is there any way to iterate through a Python list while using control_flow_ops.while_loop?

I am using control_flow_ops.while_loop to iterate through a sequence in my RNN. I also have a list L of functions that I want to call.

Let 'j' represent the current index of the sequence in the while_loop. As I iterate through, I want to call the function at index j in L.

But calling L[j] doesn't work: j is a tensor representing the index into the sequence, while L is a Python list. tf.gather doesn't help either, because L cannot be converted to a tensor (its elements are functions).

How do I access the elements of L in the while_loop? Example below:

import tensorflow as tf
from tensorflow.python.ops import control_flow_ops

L = [fn1, fn2, fn3]

def body(j, seq):
    seq_elem = seq.read(j)

    # Try to get the function stored in L.
    fn = L[j]  # Doesn't work: L is a Python list while j is a tensor.

    fn(seq_elem)
    return j + 1, seq

_, _ = control_flow_ops.while_loop(
        cond=lambda j, _: j < seq_length,
        body=body,
        loop_vars=(tf.constant(0, dtype=tf.int32), seq))
NTruth

2 Answers


From @soloice's answer to the question "How to index a list with a TensorFlow tensor?":

"Simply run tf.gather(list, tf_look_up[index]), you'll get what you want."

Does that help?

Vib
Unfortunately not. The list L cannot be converted into a tensor, so L cannot be passed to tf.gather. – NTruth Jul 01 '17 at 22:26
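A minimal sketch of the point made in the answer and the comment above, assuming TensorFlow 2.x running eagerly: tf.gather accepts a Python list only when that list can be converted to a tensor. A list of numbers converts to a 1-D tensor and can then be indexed with a tensor, but a list of functions has no tensor representation, so the same call raises. The names `values` and `fns` are hypothetical.

```python
import tensorflow as tf

# A list of numbers converts to a 1-D float32 tensor, so a tensor index works.
values = [10.0, 20.0, 30.0]
j = tf.constant(1)
picked = tf.gather(values, j)  # scalar tensor holding 20.0

# A list of Python functions cannot be converted to a tensor, so tf.gather fails.
fns = [lambda x: x + 1.0, lambda x: x * 2.0]
try:
    tf.gather(fns, j)
    gather_failed = False
except (TypeError, ValueError):
    # The exact exception type raised by the tensor conversion may vary
    # between TensorFlow versions, so both are caught here.
    gather_failed = True
```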

A Python list whose elements cannot be converted to tensors cannot be indexed by the loop counter inside tf.while_loop (as of July 2017).

Alexandre Passos
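The list itself still cannot be indexed by a tensor, as the last answer says. A possible workaround is to use the list only while the graph is being built: tf.switch_case takes a list of zero-argument callables and, at run time, executes the branch whose position equals the int32 tensor `j`. The sketch below assumes TensorFlow 2.x, where tf.switch_case is available. It also assumes that `seq` is a 1-D float32 tensor rather than the TensorArray used in the question. The functions `fn1`, `fn2`, `fn3` and the helper `apply_all` are hypothetical stand-ins.

```python
import tensorflow as tf

# Hypothetical stand-ins for the question's fn1, fn2, fn3.
def fn1(x):
    return x + 1.0

def fn2(x):
    return x * 2.0

def fn3(x):
    return -x

L = [fn1, fn2, fn3]

@tf.function
def apply_all(seq):
    """Apply L[j] to seq[j] for each j, where j is a loop tensor (hypothetical helper)."""
    seq_length = tf.shape(seq)[0]
    out = tf.TensorArray(tf.float32, size=seq_length)

    def body(j, out):
        x = seq[j]
        # One branch per function in L; tf.switch_case runs the branch at
        # position j. Each branch must take no arguments, so the function is
        # bound through a default argument and x is captured from this scope.
        y = tf.switch_case(j, branch_fns=[lambda f=f: f(x) for f in L])
        return j + 1, out.write(j, y)

    _, out = tf.while_loop(
        cond=lambda j, _: j < seq_length,
        body=body,
        loop_vars=(tf.constant(0, dtype=tf.int32), out))
    return out.stack()

result = apply_all(tf.constant([1.0, 2.0, 3.0]))
```

Here `result` applies fn1 to the first element, fn2 to the second and fn3 to the third. All branches must return tensors of the same dtype and structure, because tf.switch_case builds every branch into the graph. On older 1.x releases, tf.case with pairs such as `(tf.equal(j, i), branch)` should serve a similar purpose.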