
I have a tensor x in PyTorch, let's say of shape (5,3,2,6), and another tensor idx of shape (5,3,2,1) which contains indices for every element in the first tensor. I want to index the first tensor using the indices in the second tensor. I tried x = x[idx], but I get a weird dimensionality, when I really want the result to be of shape (5,3,2) or (5,3,2,1).

I'll try to give an easier example: Let's say

x=torch.Tensor([[10,20,30],
                 [8,4,43]])
idx = torch.Tensor([[0],
                    [2]])

I want something like

y = x[idx]

such that y outputs [[10],[43]] or something like that.

The indices represent the positions of the wanted elements in the last dimension. In the example above, where x.shape = (2,3), the last dimension is the columns, so the indices in idx are column indices. I want this, but for more than 2 dimensions.
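To pin down the requested semantics, here is a minimal loop-based sketch (slow, for illustration only). It assumes idx has the same leading shape as x plus a trailing size-1 dimension, and holds indices into the last dimension of x; the function name index_last_dim_loop is hypothetical.

```python
import itertools
import torch

def index_last_dim_loop(x, idx):
    # Hypothetical reference implementation: y[..., 0] = x[..., idx[..., 0]]
    lead_shape = x.shape[:-1]
    y = torch.empty(*lead_shape, 1, dtype=x.dtype)
    # Visit every position of the leading dimensions
    for pos in itertools.product(*(range(n) for n in lead_shape)):
        k = int(idx[pos + (0,)])  # index into the last dimension at this position
        y[pos + (0,)] = x[pos + (k,)]
    return y

x = torch.Tensor([[10, 20, 30], [8, 4, 43]])
idx = torch.Tensor([[0], [2]])
y = index_last_dim_loop(x, idx)  # values [[10.], [43.]]
```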

Ehsan
  • How do you interpret the indices `idx=[[0],[2]]` to get the values `[[10],[43]]` from `x`? It is unclear what those indices represent: are they row/column indices or flattened-array indices? – Ehsan Jul 06 '20 at 19:04
  • It means the position in the last dimension, which in that example is the column. – Jessica Borja Jul 06 '20 at 19:13

3 Answers


From what I understand from the comments, you need idx to index the last dimension, with each entry of idx corresponding to the same position in x in all the other dimensions. In that case (this is the NumPy version; you can convert it to torch):

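import numpy as np
ind = np.indices(idx.shape)  # one index grid per dimension of idx
ind[-1] = idx                # use idx as the index into the last dimension
x[tuple(ind)]                # x and idx are NumPy arrays; idx must be integer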

output:

[[10]
 [43]]
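Since the answer leaves the PyTorch conversion to the reader, here is a minimal sketch of one possible translation. It assumes PyTorch 1.10 or later (for the indexing="ij" argument of torch.meshgrid) and an integer (torch.long) idx; the helper name index_last_dim is hypothetical. torch.meshgrid over one torch.arange per dimension plays the role of np.indices here.

```python
import torch

def index_last_dim(x, idx):
    # Build one index grid per dimension of idx, like np.indices(idx.shape)
    grids = list(torch.meshgrid(*[torch.arange(n) for n in idx.shape], indexing="ij"))
    grids[-1] = idx  # replace the last-dimension grid with the wanted indices
    return x[tuple(grids)]  # result has the same shape as idx

x = torch.randn(5, 3, 2, 6)
idx = torch.randint(0, 6, (5, 3, 2, 1))  # int64 indices into the last dimension
y = index_last_dim(x, idx)
print(y.shape)  # torch.Size([5, 3, 2, 1])
```

Note that the meshgrid grids are materialized at the full shape of idx, which costs extra memory for large tensors.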
Ehsan

You can use range for the row indices and squeeze to give idx the proper dimension (idx must be an integer tensor, e.g. torch.long), like this:

x[range(x.size(0)), idx.squeeze()]
tensor([10., 43.])

# or
x[range(x.size(0)), idx.squeeze()].unsqueeze(1)
tensor([[10.],
        [43.]])
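The snippet above assumes a 2-D x and an integer idx; with idx = torch.Tensor(...) as in the question it is a float tensor and has to be converted with .long() first. A minimal sketch of extending the range idea to any number of leading dimensions, assuming idx has a trailing size-1 dimension: reshape one torch.arange per leading dimension so that they broadcast against each other. The helper name index_last_dim_broadcast is hypothetical.

```python
import torch

def index_last_dim_broadcast(x, idx):
    lead = x.shape[:-1]
    n_lead = len(lead)
    aranges = []
    for d, n in enumerate(lead):
        # Shape (1, ..., n, ..., 1) with n at position d, so the aranges broadcast
        shape = [1] * n_lead
        shape[d] = n
        aranges.append(torch.arange(n).view(*shape))
    # .long() handles float indices; squeeze(-1) drops only the trailing size-1 dim
    last = idx.long().squeeze(-1)
    return x[(*aranges, last)]  # result shape equals x.shape[:-1]

x = torch.Tensor([[10, 20, 30], [8, 4, 43]])
idx = torch.Tensor([[0], [2]])  # float, as in the question
y = index_last_dim_broadcast(x, idx)  # values [10., 43.]
```

Using squeeze(-1) instead of squeeze() avoids accidentally removing a leading dimension that happens to have size 1.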
Dishin H Goyani

Here's a version that works in PyTorch using torch.gather. idx needs to have dtype torch.int64, which the following line ensures (note the lowercase 't' in torch.tensor).

idx = torch.tensor([[0],
                    [2]])
torch.gather(x, 1, idx) # 1 is the axis to index here
tensor([[10.],
        [43.]])
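For the question's original shapes the same call works with dim=-1: torch.gather needs idx to have the same number of dimensions as x, which the trailing size-1 dimension of idx already provides. A minimal sketch, assuming random data of those shapes; squeeze(-1) gives the (5,3,2) variant.

```python
import torch

x = torch.randn(5, 3, 2, 6)
idx = torch.randint(0, 6, (5, 3, 2, 1))  # int64 indices into the last dimension

y = torch.gather(x, -1, idx)  # y[i, j, k, 0] = x[i, j, k, idx[i, j, k, 0]]
print(y.shape)                # torch.Size([5, 3, 2, 1])
print(y.squeeze(-1).shape)    # torch.Size([5, 3, 2])
```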