
I have a tensor t of shape (2, 3, 4):

import numpy as np
import tensorflow as tf

t = tf.random.normal((2, 3, 4))

<tf.Tensor: id=55, shape=(2, 3, 4), dtype=float32, numpy=
array([[[-0.86664855, -0.32786712, -0.9517335 ,  0.989722  ],
        [-0.25011402, -0.35941386, -1.0808105 ,  0.60205466],
        [ 0.07523973, -0.6512919 ,  1.3695312 , -1.5043781 ]],

       [[ 0.33990988, -0.17364176,  0.72955394, -0.7119293 ],
        [ 0.4013214 ,  0.5653289 ,  1.4327284 ,  1.2687784 ],
        [-1.1986154 ,  1.3783301 ,  1.714094  ,  0.49866664]]],
      dtype=float32)>

and a tensor of indices idx of shape (2, 3) whose values index into the last dimension of t:

idx = tf.convert_to_tensor(np.random.randint(4, size=(2, 3)))

<tf.Tensor: id=56, shape=(2, 3), dtype=int64, numpy=
array([[2, 2, 3],
       [0, 3, 1]])>

How can I extract the elements of t along its last dimension at the indices specified by idx, so that result[i, j] = t[i, j, idx[i, j]]? The result should be the following tensor of shape (2, 3):

<tf.Tensor: id=57, shape=(2, 3), dtype=float32, numpy=
array([[-0.9517335 , -1.0808105 , -1.5043781 ],
       [ 0.33990988,  1.2687784 ,  1.3783301 ]], dtype=float32)>
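To pin down the semantics, here is a plain NumPy restatement (a reference sketch only, not TensorFlow code) built from the values printed above. Each output element is result[i, j] = t[i, j, idx[i, j]]; NumPy's np.take_along_axis computes the same thing without loops, provided idx gets a trailing length-1 axis that is dropped again afterwards.

```python
import numpy as np

# Values copied from the tensors printed in the question.
t = np.array([[[-0.86664855, -0.32786712, -0.9517335 ,  0.989722  ],
               [-0.25011402, -0.35941386, -1.0808105 ,  0.60205466],
               [ 0.07523973, -0.6512919 ,  1.3695312 , -1.5043781 ]],
              [[ 0.33990988, -0.17364176,  0.72955394, -0.7119293 ],
               [ 0.4013214 ,  0.5653289 ,  1.4327284 ,  1.2687784 ],
               [-1.1986154 ,  1.3783301 ,  1.714094  ,  0.49866664]]], dtype=np.float32)
idx = np.array([[2, 2, 3],
                [0, 3, 1]])

# Definition of the wanted result: result[i, j] = t[i, j, idx[i, j]]
result = np.empty(idx.shape, dtype=t.dtype)
for i in range(idx.shape[0]):
    for j in range(idx.shape[1]):
        result[i, j] = t[i, j, idx[i, j]]

# Vectorized equivalent: add a trailing axis to idx, then drop it from the output.
vectorized = np.take_along_axis(t, idx[..., None], axis=-1)[..., 0]
```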

I have tried regular indexing without success:

t[:, :, idx]  # error
t[..., idx]   # error

as well as tf.gather and tf.gather_nd:

tf.gather(t, idx, axis=2)  # has shape (2, 3, 2, 3)
tf.gather_nd(t, idx)       # has shape (2, )

Neither of these does what I want.
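The (2, 3, 2, 3) shape from tf.gather(t, idx, axis=2) comes from pairing every (i, j) of t with every entry idx[k, l], so out[i, j, k, l] == t[i, j, idx[k, l]]. The wanted values sit on the "diagonal" where (k, l) == (i, j). A small check of that reading, with random inputs as in the question (this computes far more than needed and is only meant as an explanation):

```python
import numpy as np
import tensorflow as tf

t = tf.random.normal((2, 3, 4))
idx = tf.convert_to_tensor(np.random.randint(4, size=(2, 3)))

# Every (i, j) row of t is paired with every index idx[k, l].
out = tf.gather(t, idx, axis=2)
print(out.shape)  # (2, 3, 2, 3)

# Keep only the entries where (k, l) == (i, j).
out_np = out.numpy()
diag = np.array([[out_np[i, j, i, j] for j in range(3)] for i in range(2)])

# Reference: result[i, j] = t[i, j, idx[i, j]]
ref = np.take_along_axis(t.numpy(), idx.numpy()[..., None], axis=-1)[..., 0]
```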

Jon Deaton

3 Answers


Think again about what you are trying to achieve. What are the indices of the elements you want to extract along the first and second axes? From your example, it seems you are thinking of flattening the first two dimensions, so that t has shape (6, 4), and extracting the elements whose first-axis indices are 0, 1, ..., 5 and whose second-axis indices are given by idx.

To achieve this, you have to specify the indices for all dimensions explicitly. Start by reshaping t to 2-D:

t_2d = tf.reshape(t, [-1, tf.shape(t)[-1]])

<tf.Tensor: id=55, shape=(6, 4), dtype=float32, numpy=
array([[-0.86664855, -0.32786712, -0.9517335 ,  0.989722  ],
       [-0.25011402, -0.35941386, -1.0808105 ,  0.60205466],
       [ 0.07523973, -0.6512919 ,  1.3695312 , -1.5043781 ],
       [ 0.33990988, -0.17364176,  0.72955394, -0.7119293 ],
       [ 0.4013214 ,  0.5653289 ,  1.4327284 ,  1.2687784 ],
       [-1.1986154 ,  1.3783301 ,  1.714094  ,  0.49866664]],
      dtype=float32)>

Now specify the indices along the first axis, i.e. the row of t_2d that each element of idx refers to:

idx_0 = tf.reshape(tf.range(t_2d.shape[0]), idx.shape)

<tf.Tensor: id=62, shape=(2, 3), dtype=int32, numpy=
array([[0, 1, 2],
       [3, 4, 5]], dtype=int32)>
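Why a plain tf.range works for idx_0: tf.reshape flattens in row-major (C) order, so row r of t_2d is t[r // 3, r % 3], and idx_0[i, j] == i * 3 + j is exactly the row that holds t[i, j]. A quick check, assuming the same (2, 3, 4) shape as in the question:

```python
import tensorflow as tf

t = tf.random.normal((2, 3, 4))
t_2d = tf.reshape(t, [-1, tf.shape(t)[-1]])
idx_0 = tf.reshape(tf.range(t_2d.shape[0]), (2, 3))

for i in range(2):
    for j in range(3):
        row = int(idx_0[i, j])
        # The flattened row number is i * (size of axis 1) + j ...
        assert row == i * 3 + j
        # ... and that row of t_2d holds the same 4 values as t[i, j].
        assert bool(tf.reduce_all(t_2d[row] == t[i, j]))
```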

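Join the indices of the first and second axes into the [row, column] pairs expected by tf.gather_nd. Note that idx has to be cast first: tf.range produces int32, while the idx in the question is int64, and tf.stack requires both inputs to have the same dtype:

indices = tf.stack([idx_0, tf.cast(idx, idx_0.dtype)], axis=-1)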

<tf.Tensor: id=64, shape=(2, 3, 2), dtype=int32, numpy=
array([[[0, 2],
        [1, 2],
        [2, 3]],

       [[3, 0],
        [4, 3],
        [5, 1]]], dtype=int32)>
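tf.gather_nd reads the innermost axis of its indices argument as one complete coordinate into params and keeps the leading axes as the output shape, so indices of shape (2, 3, 2) into a 2-D tensor yield a (2, 3) result. The same rule on a small hand-made matrix (values chosen only for illustration):

```python
import tensorflow as tf

m = tf.constant([[10, 11, 12],
                 [20, 21, 22]])
# Each innermost [row, col] pair picks one element of m;
# the leading (2, 2) shape of `pairs` becomes the output shape.
pairs = tf.constant([[[0, 2], [1, 0]],
                     [[1, 2], [0, 0]]])
picked = tf.gather_nd(m, pairs)  # [[12, 20], [22, 10]]
```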

And finally:

tf.gather_nd(t_2d, indices)

<tf.Tensor: id=66, shape=(2, 3), dtype=float32, numpy=
array([[-0.9517335 , -1.0808105 , -1.5043781 ],
       [ 0.33990988,  1.2687784 ,  1.3783301 ]], dtype=float32)>
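The steps above can be wrapped into a small helper that also handles any number of leading axes, by flattening all of them into rows. This is a sketch: the name take_last_axis is hypothetical, it assumes idx.shape == t.shape[:-1], and it casts the row numbers to idx's dtype so int32 and int64 indices both stack cleanly.

```python
import numpy as np
import tensorflow as tf

def take_last_axis(t, idx):
    """Hypothetical helper: return out with out[...] = t[..., idx[...]]."""
    # Flatten every leading axis into a single row axis: shape (rows, last_dim).
    t_2d = tf.reshape(t, [-1, tf.shape(t)[-1]])
    flat_idx = tf.reshape(idx, [-1])
    # One row number per index, in idx's dtype so tf.stack accepts both.
    rows = tf.cast(tf.range(tf.size(flat_idx)), flat_idx.dtype)
    picked = tf.gather_nd(t_2d, tf.stack([rows, flat_idx], axis=-1))
    # Restore the leading shape of idx.
    return tf.reshape(picked, tf.shape(idx))

t = tf.random.normal((2, 3, 4))
idx = tf.convert_to_tensor(np.random.randint(4, size=(2, 3)))
result = take_last_axis(t, idx)  # shape (2, 3)
```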
Trisoloriansunscreen

Here's a solution:

import numpy as np
import tensorflow as tf

def tf_select_along_axis(arr, selecting_ixs, axis: int):
    """ Select the given indices along the given axis.
    :param arr: An N-dimensional tensor of shape (D[0], ..., D[axis], ..., D[N-1])
    :param selecting_ixs: An (N-1)-dimensional int32 or int64 tensor of shape (D[0], ..., D[axis-1], D[axis+1], ..., D[N-1]) which selects elements along axis
    :param axis: The axis along which you're selecting (negative values count from the end).
    """
    n = len(arr.shape)
    axis = axis % n
    ixs = []
    for d in range(n):
        if d == axis:
            ixs.append(selecting_ixs)
            continue
        # Position of dimension d inside selecting_ixs, which has no `axis` dimension.
        p = d if d < axis else d - 1
        # Shape range(D[d]) as (D[d], 1, ..., 1) so it varies only along position p, then broadcast it.
        r = tf.reshape(tf.range(arr.shape[d], dtype=selecting_ixs.dtype), [-1] + [1] * (n - 2 - p))
        ixs.append(tf.broadcast_to(r, tf.shape(selecting_ixs)))
    # Stack into one full N-d coordinate per output element; gather_nd returns selecting_ixs.shape.
    return tf.gather_nd(arr, tf.stack(ixs, axis=-1))

This can be checked against tf.argmax and tf.reduce_max:

def test_select_along_axis():
    arr = tf.random.uniform((20, 30, 40), seed=1234)
    argmax = tf.argmax(arr, axis=1, output_type=tf.int32)
    maxval = tf.reduce_max(arr, axis=1)
    assert np.array_equal(maxval, tf_select_along_axis(arr, argmax, axis=1).numpy())
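A possibly simpler way to build the same index grid is tf.meshgrid with indexing="ij", which produces all the broadcast tf.range grids in one call. This is a swapped-in variant, not the answer's code; select_along_axis_meshgrid is a hypothetical name, and the sketch assumes a tensor of rank 2 or more with static shapes.

```python
import numpy as np
import tensorflow as tf

def select_along_axis_meshgrid(arr, selecting_ixs, axis):
    n = len(arr.shape)
    axis = axis % n
    # One index range per axis other than `axis`, in selecting_ixs' dtype.
    ranges = [tf.range(arr.shape[d], dtype=selecting_ixs.dtype) for d in range(n) if d != axis]
    # "ij" indexing gives grids shaped like selecting_ixs (no xy transposition).
    grids = list(tf.meshgrid(*ranges, indexing="ij"))
    # Put the selecting indices back at the position of `axis`.
    grids.insert(axis, selecting_ixs)
    return tf.gather_nd(arr, tf.stack(grids, axis=-1))

# Same kind of check as above, but for every axis of a rank-3 tensor.
arr = tf.random.uniform((4, 5, 6), seed=1234)
for ax in range(3):
    argmax = tf.argmax(arr, axis=ax, output_type=tf.int32)
    assert np.array_equal(tf.reduce_max(arr, axis=ax).numpy(),
                          select_along_axis_meshgrid(arr, argmax, ax).numpy())
```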

Peter

As of TensorFlow 2.4, you can use tf.experimental.numpy.take_along_axis:

tf.squeeze(tf.experimental.numpy.take_along_axis(t, idx[..., tf.newaxis], axis=-1), axis=-1)  # squeeze only the axis added by tf.newaxis
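Another option, not mentioned in the answers above, is tf.gather with its batch_dims argument. With batch_dims=2, the first two axes of t and idx are treated as paired batch axes, so each t[i, j] is gathered with its own idx[i, j]; no reshape, stack, or squeeze is needed. A minimal sketch with the question's shapes, checked against NumPy's take_along_axis:

```python
import numpy as np
import tensorflow as tf

t = tf.random.normal((2, 3, 4))
idx = tf.convert_to_tensor(np.random.randint(4, size=(2, 3)))

# batch_dims=2 pairs t[i, j] with idx[i, j]; axis defaults to batch_dims (= 2),
# so the gather happens along the last axis of t.
picked = tf.gather(t, idx, batch_dims=2)
print(picked.shape)  # (2, 3)
```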
itamar kanter