What is the best way to iterate over the final dimension of an arbitrarily shaped PyTorch tensor?

Say, I have a Tensor like:

import torch

z = torch.tensor([[[1, 2, 3], [4, 5, 6]]])
print(z.shape)
torch.Size([1, 2, 3])

I would like to iterate over the final dimension, so at each iteration I have (in this case) a [1, 2] shaped Tensor.

One option I figured out is to permute the tensor so that the final dimension becomes the first, e.g.:

for zs in z.permute(-1, *range(0, len(z.shape) - 1)):
    print(zs)

tensor([[1, 4]])
tensor([[2, 5]])
tensor([[3, 6]])

But is there a neater or faster or PyTorch-specific way?

This is very similar to this question about NumPy arrays, for which the accepted answer, i.e.,

for i in range(z.shape[-1]):
    print(z[..., i])

would also work for PyTorch tensors.

Matt Pitkin

1 Answer

The last approach you provided is clean and efficient: it directly accesses each slice along the final dimension without any explicit permutation, and it works for PyTorch tensors just as it does for NumPy arrays.

I believe it is also the fastest way to iterate for this task, since indexing with `z[..., i]` returns a view rather than copying data.
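As a sketch, the indexing loop can be compared with `torch.unbind`, a PyTorch-specific alternative that returns a tuple of all slices along a given dimension (with that dimension removed). The tensor `z` below mirrors the one in the question; both approaches should produce the same `[1, 2]`-shaped slices:

```python
import torch

z = torch.tensor([[[1, 2, 3], [4, 5, 6]]])  # shape [1, 2, 3]

# Indexing approach from the question: each slice has shape [1, 2]
indexed = [z[..., i] for i in range(z.shape[-1])]

# torch.unbind(dim=-1) yields the same slices as a tuple of views
unbound = torch.unbind(z, dim=-1)

for a, b in zip(indexed, unbound):
    assert torch.equal(a, b)
    assert a.shape == torch.Size([1, 2])
```

`unbind` is arguably the most "PyTorch-native" spelling, though for simple iteration the explicit index loop is equally idiomatic.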

Will