
I have a tensor holding a batch of permutations of the integers 0 to time-1, for example with the shape

[batch, time]

Now I want to invert all these permutations to get a tensor of the same shape. I know this can be done with tf.math.invert_permutation for a single tensor of shape [time], but that function does not support batched input: it throws an error if the input tensor has more than one dimension.
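For a single permutation `perm`, the inverse `inv` satisfies `inv[perm[i]] = i`, and the batched version is just this applied to every row. A minimal illustration of the 1-D case, assuming TensorFlow 2.x in eager mode (the example values are made up):

```python
import tensorflow as tf

# One permutation of 0..3, shape [time].
perm = tf.constant([2, 0, 3, 1])

# 1-D input is supported: inv[perm[i]] = i for every i.
inv = tf.math.invert_permutation(perm)
print(inv.numpy())  # [1 3 0 2]

# Sanity check: composing perm with its inverse gives the identity.
print(tf.gather(perm, inv).numpy())  # [0 1 2 3]
```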

What can I do to make tf.math.invert_permutation work with batched input?

Frithjof

1 Answer


One way to do it is to use the tf.data.Dataset API, applying tf.math.invert_permutation to each permutation in the batch.

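A minimal sketch of the Dataset-based idea, assuming TensorFlow 2.x with eager execution; the batch `perms` is a made-up example and this is not necessarily the answerer's exact code. Each row becomes one dataset element, `map` applies the 1-D op to it, and `batch` reassembles the rows:

```python
import tensorflow as tf

# Hypothetical batch of permutations, shape [batch, time].
perms = tf.constant([[2, 0, 3, 1],
                     [0, 1, 2, 3],
                     [3, 2, 1, 0]])

# One dataset element per row, shape [time].
ds = tf.data.Dataset.from_tensor_slices(perms)
# Invert each row with the 1-D op.
ds = ds.map(tf.math.invert_permutation)
# Re-batch all rows into a single [batch, time] tensor.
inv = next(iter(ds.batch(perms.shape[0])))
print(inv.numpy())
```

This fits most naturally when the permutations already flow through a `tf.data` input pipeline; for a tensor computed inside a model, a vectorized op is likely simpler.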

pratsbhatt
  • Thanks for the answer. Instead of using the Dataset API, I can also just use `tf.map_fn`. Good idea actually! I am still a little concerned about performance, but it's pretty simple to implement and gets me going for now. I am using: `tf.reshape(tf.map_fn(fn=tf.math.invert_permutation, elems=tf.reshape(x, [-1, x.shape[-1]])), x.shape)` – Frithjof Dec 06 '20 at 10:57
  • I am not sure about the exact performance numbers, but I would say that the preferred way to build an input data pipeline, according to TensorFlow, is the `tf.data.Dataset` API. – pratsbhatt Dec 06 '20 at 10:59
  • I did some tests with this, and in my case using `map` here was horribly slow. Instead, I ended up combining `tf.range` with `tf.scatter_nd` to implement the simple NumPy solution at the end of this answer: https://stackoverflow.com/a/25535723/2766231. – Frithjof Feb 20 '21 at 12:47
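Both approaches from the comments can be sketched side by side, assuming TensorFlow 2.x and permutations along the last axis; the function names `invert_map_fn` and `invert_scatter` are hypothetical. `invert_map_fn` mirrors the `tf.map_fn` one-liner, while `invert_scatter` uses `tf.range` plus `tf.scatter_nd` to write `t` at position `perms[b, t]` of each row, the TensorFlow analogue of the usual NumPy idiom `inv[perm] = np.arange(n)`:

```python
import tensorflow as tf


def invert_map_fn(perms):
    """Invert every permutation along the last axis, one row at a time."""
    shape = tf.shape(perms)
    flat = tf.reshape(perms, [-1, shape[-1]])  # [batch, time]
    inv = tf.map_fn(tf.math.invert_permutation, flat)
    return tf.reshape(inv, shape)


def invert_scatter(perms):
    """Vectorized inverse: inv[b, perms[b, t]] = t for every row b.

    Assumes each row along the last axis is a valid permutation of
    0..time-1; scatter_nd sums values at duplicate indices, so invalid
    input is not detected here.
    """
    perms = tf.cast(perms, tf.int32)
    orig_shape = tf.shape(perms)
    flat = tf.reshape(perms, [-1, orig_shape[-1]])  # [batch, time]
    shape = tf.shape(flat)
    batch, time = shape[0], shape[1]
    # Row index of every element: [[0, 0, ...], [1, 1, ...], ...].
    rows = tf.broadcast_to(tf.range(batch)[:, None], shape)
    # Target position (b, perms[b, t]) of every element: [batch, time, 2].
    indices = tf.stack([rows, flat], axis=-1)
    # Value written at that position: the source index t.
    updates = tf.broadcast_to(tf.range(time)[None, :], shape)
    inv = tf.scatter_nd(indices, updates, shape)
    return tf.reshape(inv, orig_shape)


perms = tf.constant([[2, 0, 3, 1],
                     [0, 1, 2, 3],
                     [3, 2, 1, 0]])
print(invert_scatter(perms).numpy())
print(invert_map_fn(perms).numpy())
```

The scatter version runs as a handful of whole-tensor ops instead of one small op per row, which is plausibly why it was faster in the test described above; actual speedups will depend on batch size and hardware.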