
a and b are torch tensors with no repeated rows. a has shape [n, 2], like:

[[1,2]
[2,3]
[4,6]
...]

b has shape [m, 2], like:

[[1,2]
[4,6]
....
]

How do I get the indices of the rows of b in a? For example:

a = [[1,2]
[2,4]
[6,7]
]
b = [[1,2]
[6,7]]

The indices should be (0, 2). We can use the GPU.

unsky

2 Answers


I can think of the following trick that can work for you.

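Since the two tensors have different numbers of rows (n and m), we first expand both to the same shape (m x n x 2) and subtract them. If two rows match, every entry of their difference is zero, so the sum of the absolute differences along the last dimension is zero. The positions of those zeros give the indices we need. Summing the raw differences is not enough: rows such as [1, 2] and [2, 1] differ by [-1, 1], which also sums to zero.

import torch

a = torch.tensor([[1, 2], [2, 4], [6, 7]])
b = torch.tensor([[1, 2], [6, 7]])

n = a.shape[0] # 3
m = b.shape[0] # 2
_a = a.unsqueeze(0).repeat(m, 1, 1) # m x n x 2
_b = b.unsqueeze(1).repeat(1, n, 1) # m x n x 2

# abs() keeps differences such as [-1, 1] from cancelling to zero
match = (_a - _b).abs().sum(-1) # m x n
indices = (match == 0).nonzero() # (row of b, row of a) pairs
if indices.nelement() > 0: # empty tensor check
    row_indices = indices[:, 1]
else:
    row_indices = []

print(row_indices)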

Sample Input/Output

Example 1

a = torch.tensor([[1, 2], [2, 4], [6, 7]])
b = torch.tensor([[1, 2], [6, 7]])
tensor([0, 2])

Example 2

a = torch.tensor([[1, 2], [2, 4], [6, 7]])
b = torch.tensor([[1, 3], [6, 7]])
tensor([2])

Example 3

a = torch.tensor([[1, 2], [2, 4], [6, 7]])
b = torch.tensor([[1, 2], [6, 5], [8, 9]])
tensor([0])

Example 4

a = torch.tensor([[1, 2], [2, 4], [6, 7]])
b = torch.tensor([[1, 3], [6, 5], [8, 9]])
[]
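
The same row-by-row comparison can also be written with broadcasting instead of repeat, and it can report which rows of b are missing rather than silently dropping them. This is only a minimal sketch: the helper name find_row_indices and the convention of returning -1 for a missing row are assumptions made for illustration, and it relies on the rows of a being unique, as the question states.

```python
import torch

def find_row_indices(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """For each row of b, return its row index in a, or -1 if it is absent.

    Hypothetical helper; assumes the rows of a are unique.
    """
    # Broadcasting compares every row of b with every row of a:
    # (m x 1 x 2) == (1 x n x 2) -> m x n x 2, reduced to m x n by all().
    eq = (b.unsqueeze(1) == a.unsqueeze(0)).all(dim=-1)
    found = eq.any(dim=1)            # m, True where the row of b exists in a
    first = eq.int().argmax(dim=1)   # m, column of the match (0 when none)
    # Replace the placeholder 0 with -1 for rows that were not found.
    return torch.where(found, first, torch.full_like(first, -1))

a = torch.tensor([[1, 2], [2, 4], [6, 7]])
b = torch.tensor([[6, 7], [1, 3], [1, 2]])
idx = find_row_indices(a, b)
print(idx)  # tensor([ 2, -1,  0])
```

Because the result has one entry per row of b, it keeps the order of b, which the nonzero-based version only does when every row of b is present in a.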
Wasi Ahmad

@jpp's NumPy solution is almost your answer.

You just need to get the indices with nonzero and flatten the result with flatten to get the expected shape.

a = torch.tensor([[1, 2], [2, 4], [6, 7]])
b = torch.tensor([[1, 2], [6, 7]])
(a[:, None] == b).all(-1).any(-1).nonzero().flatten()
tensor([0, 2])
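
The comparison above builds an intermediate of shape n x m x 2, which can get large when both tensors have many rows. One possible alternative, sketched below under stated assumptions, is to encode each row as a single integer key and test membership with torch.isin (available in recent PyTorch releases). The helper name rows_in is hypothetical, and the encoding assumes non-negative integer entries small enough that the key does not overflow int64; neither tensor may be empty.

```python
import torch

def rows_in(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Indices of the rows of a that also appear in b (hypothetical helper).

    Assumption: entries are non-negative integers, and
    first * base + second fits in int64.
    """
    # base is larger than every entry, so each (first, second) pair
    # maps to a distinct key.
    base = int(max(a.max().item(), b.max().item())) + 1
    key_a = a[:, 0].long() * base + a[:, 1].long()
    key_b = b[:, 0].long() * base + b[:, 1].long()
    # isin avoids the n x m x 2 intermediate of the broadcast comparison.
    return torch.isin(key_a, key_b).nonzero().flatten()

# Use the GPU when one is available, as the question allows.
device = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.tensor([[1, 2], [2, 4], [6, 7]], device=device)
b = torch.tensor([[1, 2], [6, 7]], device=device)
result = rows_in(a, b)
print(result.cpu())  # tensor([0, 2])
```

Like the one-liner above, this returns the indices in ascending order of a, not in the order of the rows of b.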
Dishin H Goyani