I have a tensor holding a batch of permutations of the integers 0 to time-1, which has, e.g., the shape [batch, time].
Now I want to invert all these permutations to get a tensor of the same shape.
I know this can be done using tf.math.invert_permutation for a single tensor of shape [time], but that function does not support batched input: it throws an error if the input tensor has more than one dimension.
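A minimal sketch of the behavior described above, assuming TensorFlow 2.x in eager mode. The unbatched call inverts a single permutation. The batched call is rejected because the op only accepts a 1-D vector. The exact exception type may depend on the execution mode, so the sketch catches both InvalidArgumentError and ValueError.

```python
import tensorflow as tf

# Works for a single permutation of shape [time].
x = tf.constant([3, 4, 0, 2, 1])
y = tf.math.invert_permutation(x)
print(y.numpy())  # [2 4 3 0 1]

# A batched input of shape [batch, time] is rejected, since the op
# only accepts a 1-D vector.
batch = tf.stack([x, y])
try:
    tf.math.invert_permutation(batch)
    rejected = False
except (tf.errors.InvalidArgumentError, ValueError) as e:
    rejected = True
    print(type(e).__name__)
```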
What can I do to make tf.math.invert_permutation work with batched input?
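One possible workaround, offered as a sketch rather than an official batched API: for a permutation p, argsort(p)[j] is the index i with p[i] == j, which is exactly the inverse of p. tf.argsort sorts along a chosen axis, so it handles the [batch, time] shape directly. The helper name batched_invert_permutation is hypothetical. As a cross-check, tf.map_fn applies the unbatched tf.math.invert_permutation row by row.

```python
import tensorflow as tf

def batched_invert_permutation(perms):
    """Invert every row of a [batch, time] tensor of permutations.

    Hypothetical helper: for a permutation p, argsort(p)[j] is the
    index i with p[i] == j, i.e. the inverse permutation of p.
    """
    return tf.argsort(perms, axis=-1)

perms = tf.constant([[2, 0, 1],
                     [0, 1, 2],
                     [1, 2, 0]])
inv = batched_invert_permutation(perms)
print(inv.numpy())
# [[1 2 0]
#  [0 1 2]
#  [2 0 1]]

# Cross-check: apply the unbatched op to each row with tf.map_fn.
inv_map = tf.map_fn(tf.math.invert_permutation, perms)
print(bool(tf.reduce_all(inv == inv_map)))  # True
```

The argsort route stays fully vectorized but pays for a sort (O(time log time) per row). tf.map_fn is simpler to read but processes rows one at a time, which is usually slower for large batches.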