
I want to use BoolTensor indices to slice a multidimensional tensor in PyTorch. I expect that, in the indexed tensor, the parts where the indices are True are kept, while the parts where the indices are False are removed.

My code is:

import torch
a = torch.zeros((5, 50, 5, 50))

tr_indices = torch.zeros((50), dtype=torch.bool)
tr_indices[1:50:2] = 1
val_indices = ~tr_indices

print(a[:, tr_indices].shape)
print(a[:, tr_indices, :, val_indices].shape)

I expect a[:, tr_indices, :, val_indices] to have shape [5, 25, 5, 25]; however, it has shape [25, 5, 5]. The output is:

torch.Size([5, 25, 5, 50])
torch.Size([25, 5, 5])

I'm very confused. Can anyone explain why?

iacob
nanimonai

1 Answer


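PyTorch inherits its advanced indexing behaviour from NumPy. A boolean mask used as an index is treated like the integer positions of its True entries (mask.nonzero()). When two such advanced indices appear in the same indexing operation, they are not applied independently: they are broadcast together and paired element by element. Both masks have 25 True entries, so a[:, tr_indices, :, val_indices] selects 25 pairs (the k-th True position of tr_indices with the k-th True position of val_indices) rather than all 25 × 25 combinations. Because the two advanced indices are separated by a slice, the resulting dimension of size 25 is placed first, followed by the two sliced dimensions of size 5, which gives [25, 5, 5]. If the masks had different numbers of True entries, the index shapes could not be broadcast and you would get an error instead.

Indexing twice, so that each mask is applied on its own, should give your desired output: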

a[:, tr_indices][..., val_indices]
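As a minimal sketch of how NumPy-style advanced indexing pairs the two masks, the following compares the mask index with its integer-position equivalent and with an explicit loop over the pairs, then checks the two-step indexing. Assumptions: a recent PyTorch release; the tensor is filled with torch.arange instead of zeros so that different elements can be told apart; the names tr, val, paired, manual, two_step and selected are illustrative only; and Tensor.index_select is an extra alternative, not something from the answer above.

```python
import torch

# Assumption: arange instead of zeros, so every element is distinct and the
# equality checks are meaningful. Variable names below are illustrative.
a = torch.arange(5 * 50 * 5 * 50, dtype=torch.float32).reshape(5, 50, 5, 50)

tr_indices = torch.zeros(50, dtype=torch.bool)
tr_indices[1:50:2] = True          # odd positions are True (25 of them)
val_indices = ~tr_indices          # even positions are True (25 of them)

# A boolean mask behaves like the integer positions of its True entries.
tr = tr_indices.nonzero().squeeze(1)
val = val_indices.nonzero().squeeze(1)

# Two masks in one indexing call are paired element by element, not combined.
paired = a[:, tr_indices, :, val_indices]
assert torch.equal(paired, a[:, tr, :, val])

# Explicit version: the k-th slice is a[:, tr[k], :, val[k]], and the paired
# dimension (size 25) comes first because a slice separates the two indices.
manual = torch.stack([a[:, t, :, v] for t, v in zip(tr.tolist(), val.tolist())])
assert torch.equal(paired, manual)
print(paired.shape)    # torch.Size([25, 5, 5])

# Applying each mask in its own indexing step gives the intended block.
two_step = a[:, tr_indices][..., val_indices]
selected = a.index_select(1, tr).index_select(3, val)  # extra alternative
assert torch.equal(two_step, selected)
print(two_step.shape)  # torch.Size([5, 25, 5, 25])
```

index_select takes integer indices along a single dimension, so applying it once per dimension sidesteps the pairing rule entirely.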
iacob