
I have a tensor of shape (30, 116, 10) and I want to swap the first two dimensions, so that I get a tensor of shape (116, 30, 10).

I saw that NumPy has such a function (np.swapaxes), and I searched for something similar in TensorFlow but found nothing.

Do you have any idea?

nbro
Alexis Rosuel

3 Answers


tf.transpose provides the same functionality as np.swapaxes, in a more general form: instead of naming the two axes to swap, you pass the complete permutation of axes. In your case, tf.transpose(orig_tensor, [1, 0, 2]) is equivalent to np.swapaxes(orig_np_array, 0, 1).
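As a sketch of how np.swapaxes maps onto tf.transpose, the hypothetical helper `swapaxes` below builds the permutation by swapping two entries of the identity permutation `range(rank)`. It assumes the rank of the input is known statically (`x.shape.rank` is not None); the random input tensor is only for illustration.

```python
import numpy as np
import tensorflow as tf


def swapaxes(x, axis1, axis2):
    """Hypothetical helper mimicking np.swapaxes via tf.transpose.

    Assumes the rank of `x` is known statically (x.shape.rank is not None).
    """
    rank = x.shape.rank
    # Normalise negative indices (e.g. -1 -> rank - 1), as NumPy does.
    axis1 %= rank
    axis2 %= rank
    # Start from the identity permutation and swap the two requested entries.
    perm = list(range(rank))
    perm[axis1], perm[axis2] = perm[axis2], perm[axis1]
    return tf.transpose(x, perm)


orig_tensor = tf.random.uniform([30, 116, 10])
swapped = swapaxes(orig_tensor, 0, 1)
print(swapped.shape)  # (116, 30, 10)

# Same values as NumPy's swapaxes on the equivalent array
print(np.array_equal(swapped.numpy(), np.swapaxes(orig_tensor.numpy(), 0, 1)))  # True
```

For `[30, 116, 10]` and axes 0 and 1 the helper produces exactly the `[1, 0, 2]` permutation from the answer.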

Timo Denk
keveman
    What if I don't know the dimensions of my input tensor but I'm sure I want to swap the last 2 axes? Like, what should I do to a tensor variable so that an input of shape `(2, 3, 4, 5)` will end up as `(2, 3, 5, 4)` but the same should work on an input of shape `(3, 4, 5, 6, 7)` (and turn it into `(3, 4, 5, 7, 6)`) – Konstantinos Bairaktaris May 09 '18 at 19:48
    @KonstantinosBairaktaris see my answer – joel Dec 07 '20 at 21:37

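It is possible to use tf.einsum to swap axes even if the number of input dimensions is unknown, because the ellipsis (...) in the equation stands for all the axes that are not named explicitly. For example: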

  • tf.einsum("ij...->ji...", input) will swap the first two dimensions of input;
  • tf.einsum("...ij->...ji", input) will swap the last two dimensions;
  • tf.einsum("aij...->aji...", input) will swap the second and third dimensions;
  • tf.einsum("ijk...->kij...", input) will permute the first three dimensions, moving the third one to the front;

and so on.
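A short sketch of these patterns applied to inputs of different ranks (the shapes are arbitrary examples). The same equation string works for both ranks because the ellipsis absorbs the axes that are not named.

```python
import numpy as np
import tensorflow as tf

# Two inputs with different ranks (arbitrary example shapes)
x4 = tf.random.uniform([2, 3, 4, 5])
x5 = tf.random.uniform([3, 4, 5, 6, 7])

# "...ij->...ji": swap the last two dimensions, whatever the rank
y4 = tf.einsum("...ij->...ji", x4)
y5 = tf.einsum("...ij->...ji", x5)
print(y4.shape)  # (2, 3, 5, 4)
print(y5.shape)  # (3, 4, 5, 7, 6)

# "ij...->ji...": swap the first two dimensions, as in the question
x = tf.random.uniform([30, 116, 10])
z = tf.einsum("ij...->ji...", x)
print(z.shape)  # (116, 30, 10)

# "ijk...->kij...": the third dimension moves to the front
w = tf.einsum("ijk...->kij...", x5)
print(w.shape)  # (5, 3, 4, 6, 7)
```

Each equation only relabels axes, so the result should match the corresponding NumPy call (np.swapaxes or np.moveaxis) on the same data.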

ltskv

You can transpose just the last two axes with tf.linalg.matrix_transpose. More generally, you can permute any number of trailing axes by computing the leading axis indices dynamically and using negative (relative) indices for the trailing axes you want to transpose:

import tensorflow as tf

x = tf.ones([5, 3, 7, 11])
trailing_axes = [-1, -2]                              # trailing axes, in their new order

leading = tf.range(tf.rank(x) - len(trailing_axes))   # [0, 1]
trailing = trailing_axes + tf.rank(x)                 # [3, 2] (relative -> absolute indices)
new_order = tf.concat([leading, trailing], axis=0)    # [0, 1, 3, 2]
res = tf.transpose(x, new_order)
res.shape                                             # [5, 3, 11, 7]
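To see that the permutation really is computed at run time, the sketch below wraps the same idea in a hypothetical function `swap_last_two` and traces it with tf.function using an input signature of unknown rank (`tf.TensorSpec(shape=None)`). The traced function then handles both shapes from the comment under the first answer; tf.linalg.matrix_transpose is shown alongside for comparison.

```python
import tensorflow as tf


# Hypothetical helper: shape=None means the rank is unknown at trace time,
# so the permutation has to be built from tf.rank(x) inside the graph.
@tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
def swap_last_two(x):
    rank = tf.rank(x)
    leading = tf.range(rank - 2)                  # [0, 1, ..., rank - 3]
    trailing = tf.stack([rank - 1, rank - 2])     # last two axes, swapped
    return tf.transpose(x, tf.concat([leading, trailing], axis=0))


a = swap_last_two(tf.ones([2, 3, 4, 5]))
b = swap_last_two(tf.ones([3, 4, 5, 6, 7]))
print(a.shape)  # (2, 3, 5, 4)
print(b.shape)  # (3, 4, 5, 7, 6)

# Built-in alternative for the last-two-axes case
c = tf.linalg.matrix_transpose(tf.ones([3, 4, 5, 6, 7]))
print(c.shape)  # (3, 4, 5, 7, 6)
```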
joel