I have a tensor x in PyTorch, say of shape (5,3,2,6), and another tensor idx of shape (5,3,2,1) that contains, for every position in the first tensor, an index into its last dimension. I want to index the first tensor with the indices from the second tensor. I tried x = x[idx], but the result has an unexpected shape, when what I want is a result of shape (5,3,2) or (5,3,2,1).
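A likely explanation for the unexpected shape: indexing a tensor with a single integer tensor selects along dimension 0 only, so x[idx] has shape idx.shape + x.shape[1:] rather than picking one element per position. The sketch below assumes random data and keeps the hypothetical index values below 5 so they are also valid positions along dim 0. With values up to 5 (valid for the last dimension of size 6) the same line would raise an IndexError instead.

```python
import torch

x = torch.randn(5, 3, 2, 6)
# Hypothetical indices, kept in [0, 5) so they are valid along dim 0,
# which is the dimension x[idx] actually indexes.
idx = torch.randint(0, 5, (5, 3, 2, 1))

# Integer-tensor indexing picks whole slices x[n] for every entry n of idx
y = x[idx]
print(y.shape)  # torch.Size([5, 3, 2, 1, 3, 2, 6])
```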
Here is a simpler example. Let's say
x=torch.Tensor([[10,20,30],
[8,4,43]])
idx = torch.tensor([[0],
[2]])  # int64: index tensors must be integer-typed, not float
I want something like
y = x[idx]
such that y is [[10],[43]],
or something similar.
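For the 2D example, a minimal sketch using torch.gather, which picks one element per position along a given dimension. It assumes idx is an int64 tensor with the same number of dimensions as x. Along dim=1, gather computes out[i][j] = x[i][idx[i][j]].

```python
import torch

x = torch.tensor([[10, 20, 30],
                  [8, 4, 43]])
# gather needs an int64 index tensor with the same number of dims as x
idx = torch.tensor([[0],
                    [2]])

# out[i][j] = x[i][idx[i][j]] when gathering along dim=1 (the columns)
y = torch.gather(x, 1, idx)
print(y.tolist())  # [[10], [43]]
```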
The indices give the position of the wanted elements along the last dimension. In the example above, where x.shape = (2,3), the last dimension is the columns, so each entry of idx is a column index. I want the same thing for tensors with more than two dimensions.
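The same approach should extend to any number of dimensions by gathering along the last dimension (dim=-1). This sketch uses random data and hypothetical indices in [0, 6). The output of gather has the shape of idx, here (5,3,2,1), and squeezing the trailing axis gives (5,3,2).

```python
import torch

x = torch.randn(5, 3, 2, 6)
# Hypothetical indices into the last dimension (values in [0, 6))
idx = torch.randint(0, 6, (5, 3, 2, 1))

# Gather along the last dimension; the output has the same shape as idx
y = torch.gather(x, -1, idx)   # shape (5, 3, 2, 1)
y_squeezed = y.squeeze(-1)     # shape (5, 3, 2)
print(y.shape, y_squeezed.shape)
```

Equivalently, x.gather(-1, idx) can be written as a method call. PyTorch 1.9 and later also provide torch.take_along_dim(x, idx, dim=-1), modeled on NumPy's take_along_axis.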