I have a tensor t of shape (2, 3, 4):
import numpy as np
import tensorflow as tf

t = tf.random.normal((2, 3, 4))
<tf.Tensor: id=55, shape=(2, 3, 4), dtype=float32, numpy=
array([[[-0.86664855, -0.32786712, -0.9517335 , 0.989722 ],
[-0.25011402, -0.35941386, -1.0808105 , 0.60205466],
[ 0.07523973, -0.6512919 , 1.3695312 , -1.5043781 ]],
[[ 0.33990988, -0.17364176, 0.72955394, -0.7119293 ],
[ 0.4013214 , 0.5653289 , 1.4327284 , 1.2687784 ],
[-1.1986154 , 1.3783301 , 1.714094 , 0.49866664]]],
dtype=float32)>
and a set of indices idx of shape (2, 3) whose values index into the last dimension of t:
idx = tf.convert_to_tensor(np.random.randint(4, size=(2, 3)))
<tf.Tensor: id=56, shape=(2, 3), dtype=int64, numpy=
array([[2, 2, 3],
[0, 3, 1]])>
How can I extract the elements of t along its last dimension at the indices given by idx? The result should be the following tensor of shape (2, 3):
<tf.Tensor: id=57, shape=(2, 3), dtype=float32, numpy=
array([[-0.9517335 , -1.0808105 , -1.5043781 ],
[ 0.33990988, 1.2687784 , 1.3783301 ]], dtype=float32)>
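To pin down the intended semantics, result[i, j] = t[i, j, idx[i, j]], here is a small NumPy sketch (NumPy rather than TensorFlow, used only as a reference; the arrays are the values printed above). It uses np.take_along_axis and checks it against an explicit loop:

```python
import numpy as np

# Values copied from the tensors printed above.
t = np.array(
    [[[-0.86664855, -0.32786712, -0.9517335, 0.989722],
      [-0.25011402, -0.35941386, -1.0808105, 0.60205466],
      [0.07523973, -0.6512919, 1.3695312, -1.5043781]],
     [[0.33990988, -0.17364176, 0.72955394, -0.7119293],
      [0.4013214, 0.5653289, 1.4327284, 1.2687784],
      [-1.1986154, 1.3783301, 1.714094, 0.49866664]]],
    dtype=np.float32,
)
idx = np.array([[2, 2, 3],
                [0, 3, 1]])

# take_along_axis needs idx to have the same rank as t, so add a trailing
# axis of size 1, gather along the last axis, then drop that axis again.
result = np.take_along_axis(t, idx[..., None], axis=-1)[..., 0]

# Equivalent explicit loop: result[i, j] = t[i, j, idx[i, j]]
expected = np.empty(idx.shape, dtype=t.dtype)
for i in range(idx.shape[0]):
    for j in range(idx.shape[1]):
        expected[i, j] = t[i, j, idx[i, j]]

print(result.shape)  # (2, 3)
```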
I have tried regular indexing without success:
t[:, :, idx] # error
t[..., idx] # error
and also tf.gather and tf.gather_nd:
tf.gather(t, idx, axis=2) # has shape (2, 3, 2, 3)
tf.gather_nd(t, idx) # has shape (2, )
Neither of these seems to accomplish this.
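A possible approach, sketched under the assumption of a TensorFlow 2.x release in which tf.gather and tf.gather_nd both accept a batch_dims argument: with batch_dims=2, the first two axes of t and idx are treated as batch axes, so each idx[i, j] selects one element of t[i, j, :]. The third variant avoids batch_dims by building full (i, j, k) coordinates, which is the form plain tf.gather_nd expects. The NumPy comparison is only a check.

```python
import numpy as np
import tensorflow as tf

t = tf.random.normal((2, 3, 4))
idx = tf.convert_to_tensor(np.random.randint(4, size=(2, 3)))

# Option 1: tf.gather with the leading 2 axes as batch axes. axis defaults to
# batch_dims (here the last axis of t), and each idx[i, j] is a scalar index.
out_gather = tf.gather(t, idx, batch_dims=2)

# Option 2: tf.gather_nd expects a trailing index-vector axis, so give each
# index a length-1 coordinate into the remaining (last) axis of t.
out_gather_nd = tf.gather_nd(t, idx[..., tf.newaxis], batch_dims=2)

# Option 3 (no batch_dims): build full (i, j, k) coordinates explicitly.
# meshgrid with indexing="ij" gives i[a, b] = a and j[a, b] = b, shape (2, 3).
i, j = tf.meshgrid(tf.range(2), tf.range(3), indexing="ij")
coords = tf.stack([i, j, tf.cast(idx, tf.int32)], axis=-1)  # shape (2, 3, 3)
out_full = tf.gather_nd(t, coords)

# NumPy reference: ref[i, j] = t[i, j, idx[i, j]]
ref = np.take_along_axis(t.numpy(), idx.numpy()[..., None], axis=-1)[..., 0]
```

The mismatch in the attempts above comes from how the indices are interpreted: plain tf.gather with axis=2 takes every index for every (i, j) pair, and plain tf.gather_nd reads each row of idx as one complete coordinate into t.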