Here is a trick that should work for you.
Since the two tensors have different numbers of rows (n and m), we first expand both to the same shape (m x n x 2) and then subtract. If two rows match, every element of the corresponding difference row is zero. All that remains is to find the indices of those all-zero rows.
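To make the shapes concrete, here is a sketch of the expand-and-subtract step using the data from Example 1. It marks which (row of b, row of a) pairs give an all-zero difference. The test has to be "every element is zero": a plain signed sum of the differences can cancel for rows that do not match, as the last two lines show with `[1, 2]` vs `[2, 1]`.

```python
import torch

a = torch.tensor([[1, 2], [2, 4], [6, 7]])  # n = 3 rows
b = torch.tensor([[1, 2], [6, 7]])          # m = 2 rows

n, m = a.shape[0], b.shape[0]
_a = a.unsqueeze(0).repeat(m, 1, 1)  # (m, n, 2): all of a, once per row of b
_b = b.unsqueeze(1).repeat(1, n, 1)  # (m, n, 2): each row of b, repeated n times
diff = _a - _b                       # diff[i, j] = a[j] - b[i]

# (m, n) boolean mask: True where b[i] equals a[j] in every column
all_zero = (diff == 0).all(-1)
print(all_zero)  # b[0] matches a[0], b[1] matches a[2]

# Pitfall: signed differences can cancel out when summed
d = torch.tensor([1, 2]) - torch.tensor([2, 1])
print(d.sum().item(), d.abs().sum().item())  # 0 2
```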
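```python
import torch

# a: (n, 2) tensor, b: (m, 2) tensor, e.g. as defined in the examples below
n = a.shape[0]  # rows in a (3 in the examples)
m = b.shape[0]  # rows in b (2 in Example 1)
_a = a.unsqueeze(0).repeat(m, 1, 1)  # m x n x 2
_b = b.unsqueeze(1).repeat(1, n, 1)  # m x n x 2
# Absolute differences, so positive and negative differences cannot cancel out
match = (_a - _b).abs().sum(-1)  # m x n, zero only where rows are identical
indices = (match == 0).nonzero()  # (row of b, row of a) pairs
if indices.nelement() > 0:  # empty tensor check
    row_indices = indices[:, 1]
else:
    row_indices = []
print(row_indices)
```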
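One behaviour worth knowing: `nonzero()` walks the m x n mask row by row, so `indices[:, 1]` follows the row order of b and repeats an index if b contains the same row twice. The sketch below uses a hypothetical b (reordered, with a duplicate row) to show this, and `torch.unique` to get sorted, duplicate-free indices.

```python
import torch

a = torch.tensor([[1, 2], [2, 4], [6, 7]])
# Hypothetical b: rows reordered, and [6, 7] appears twice
b = torch.tensor([[6, 7], [1, 2], [6, 7]])

n, m = a.shape[0], b.shape[0]
_a = a.unsqueeze(0).repeat(m, 1, 1)
_b = b.unsqueeze(1).repeat(1, n, 1)
match = (_a - _b).abs().sum(-1)
indices = (match == 0).nonzero()

row_indices = indices[:, 1]
print(row_indices)                # tensor([2, 0, 2]): follows the row order of b
print(torch.unique(row_indices))  # tensor([0, 2]): sorted, duplicates removed
```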
Sample Input/Output
Example 1
```python
a = torch.tensor([[1, 2], [2, 4], [6, 7]])
b = torch.tensor([[1, 2], [6, 7]])
# Output: tensor([0, 2])
```
Example 2
```python
a = torch.tensor([[1, 2], [2, 4], [6, 7]])
b = torch.tensor([[1, 3], [6, 7]])
# Output: tensor([2])
```
Example 3
```python
a = torch.tensor([[1, 2], [2, 4], [6, 7]])
b = torch.tensor([[1, 2], [6, 5], [8, 9]])
# Output: tensor([0])
```
Example 4
```python
a = torch.tensor([[1, 2], [2, 4], [6, 7]])
b = torch.tensor([[1, 3], [6, 5], [8, 9]])
# Output: [] (no row matches, so the else branch returns an empty list)
```
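If you prefer a result that is always a tensor (including the empty case in Example 4, where the code above falls back to a Python list), sorted and free of duplicates, the same comparison can be written with broadcasting and `==` instead of `repeat` and subtraction. This is a sketch of that variant; the helper name `matching_row_indices` is my own, not a PyTorch API.

```python
import torch

def matching_row_indices(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Sorted indices of the rows of `a` that also occur as a row of `b`."""
    # (1, n, d) == (m, 1, d) broadcasts to (m, n, d) without calling repeat()
    eq = a.unsqueeze(0) == b.unsqueeze(1)
    pair_match = eq.all(dim=-1)   # (m, n): b[i] equals a[j] in every column
    in_b = pair_match.any(dim=0)  # (n,): a[j] matches at least one row of b
    return in_b.nonzero(as_tuple=True)[0]  # 1-D index tensor, possibly empty

a = torch.tensor([[1, 2], [2, 4], [6, 7]])
print(matching_row_indices(a, torch.tensor([[1, 2], [6, 7]])))          # tensor([0, 2])
print(matching_row_indices(a, torch.tensor([[1, 3], [6, 7]])))          # tensor([2])
print(matching_row_indices(a, torch.tensor([[1, 2], [6, 5], [8, 9]])))  # tensor([0])
print(matching_row_indices(a, torch.tensor([[1, 3], [6, 5], [8, 9]])))  # empty tensor
```

Because the result comes from a mask over the rows of `a`, the indices are in ascending order and each appears at most once, regardless of how `b` is ordered.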