I have a tensor of shape (?, 3, 2, 5). I want to supply pairs of indices to select from the first and second dimensions after the batch dimension, which have shape (3, 2).
If I supply 4 such pairs, I would expect the resulting shape to be (?, 4, 5). I'd thought this is what batch_gather is for: to "broadcast" the gather indices over the first (batch) dimension. But this is not what it's doing:
import tensorflow as tf
data = tf.placeholder(tf.float32, (None, 3, 2, 5))
indices = tf.constant([
    [2, 1],
    [2, 0],
    [1, 1],
    [0, 1]
], tf.int32)
tf.batch_gather(data, indices)
This results in <tf.Tensor 'Reshape_3:0' shape=(4, 2, 2, 5) dtype=float32> instead of the (?, 4, 5) shape that I was expecting.
How can I do what I want without explicitly indexing the batches (which have an unknown size)?
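To make the expected result concrete, here is a minimal sketch of the selection in NumPy rather than TensorFlow. The batch size of 7 is an arbitrary stand-in for the unknown dimension, and the names expected, flat_indices and via_flatten are only illustrative. The second formulation merges the (3, 2) axes into a single axis of size 6 and gathers along it. That idea might carry over to TensorFlow via tf.reshape followed by tf.gather with axis=1, but I have not verified that this handles the unknown batch size.

```python
import numpy as np

batch = 7  # arbitrary stand-in for the unknown batch size (assumption)
data = np.arange(batch * 3 * 2 * 5, dtype=np.float32).reshape(batch, 3, 2, 5)
indices = np.array([[2, 1], [2, 0], [1, 1], [0, 1]], dtype=np.int32)

# Desired semantics: for every batch element b, pick data[b, i, j, :]
# for each (i, j) pair in indices.
expected = data[:, indices[:, 0], indices[:, 1]]

# Equivalent formulation: merge the (3, 2) axes into one axis of size 6
# and select along it with the flattened indices i * 2 + j.
flat_indices = indices[:, 0] * data.shape[2] + indices[:, 1]
via_flatten = data.reshape(batch, -1, data.shape[3])[:, flat_indices]

print(expected.shape)  # (7, 4, 5)
```

Both formulations give the same array, so the result I am after has one row of 5 values per index pair, per batch element.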