What is the best way to iterate over the final dimension of an arbitrarily shaped PyTorch tensor?
Say I have a tensor like:

```python
import torch

z = torch.tensor([[[1, 2, 3], [4, 5, 6]]])
print(z.shape)  # torch.Size([1, 2, 3])
```
I would like to iterate over the final dimension, so that at each iteration I have (in this case) a tensor of shape [1, 2].
One option I figured out was to permute the tensor so that the final dimension becomes the first dimension, e.g.:

```python
for zs in z.permute(-1, *range(0, len(z.shape) - 1)):
    print(zs)

# tensor([[1, 4]])
# tensor([[2, 5]])
# tensor([[3, 6]])
```
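A more compact way to express the same permute-then-iterate idea is `Tensor.movedim(-1, 0)` (available in recent PyTorch versions). It moves only the last dimension to the front and keeps the others in their original order, so you don't have to build the `range(...)` yourself. This is a minimal sketch assuming the same `z` as above. It is not part of the original question:

```python
import torch

z = torch.tensor([[[1, 2, 3], [4, 5, 6]]])

# Move the last dimension to position 0; the remaining dimensions keep their order.
# For this 3-D tensor this matches z.permute(2, 0, 1).
moved = z.movedim(-1, 0)

# Iterating over a tensor walks its first dimension, which is now the old last one
slices = [zs for zs in moved]
for zs in slices:
    print(zs)  # each zs has shape [1, 2]
```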
But is there a neater, faster, or more PyTorch-specific way?
This is very similar to this question about NumPy arrays. Its accepted answer would also work for PyTorch tensors:

```python
for i in range(z.shape[-1]):
    print(z[..., i])
```
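One PyTorch-specific option, not mentioned in the original post, is `torch.unbind` (also available as the method `z.unbind(-1)`). It removes the given dimension and returns a tuple of the slices along it. With `dim=-1`, this works for a tensor of any number of dimensions. The sketch below assumes the same `z` as above. It makes no claim about speed relative to the other approaches. It only checks that the slices match the `z[..., i]` indexing loop:

```python
import torch

z = torch.tensor([[[1, 2, 3], [4, 5, 6]]])

# Split z along its final dimension; the result is a tuple with one tensor
# per index of that dimension, each with the final dimension removed
slices = torch.unbind(z, dim=-1)
for zs in slices:
    print(zs)  # each zs has shape [1, 2]

# Same slices as the indexing loop z[..., i]
assert all(torch.equal(s, z[..., i]) for i, s in enumerate(slices))
```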