2

I have two 2D tensors of different lengths. Both are different subsets of the same original 2D tensor, and I would like to find all the matching "rows",
e.g.

A = [[1,2,3],[4,5,6],[7,8,9],[3,3,3]]
B = [[1,2,3],[7,8,9],[4,4,4]]
torch.2dintersect(A,B) -> [0,2] (the indices of the rows of A that B also has)

I've only seen NumPy solutions, which rely on structured dtypes and do not work in PyTorch.

Here is how I do it in NumPy:

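import numpy as np

arr1 = edge_index_dense.numpy().view(np.int32)
arr2 = edge_index2_dense.numpy().view(np.int32)
# View each row as one structured element so that whole rows are compared
arr1_view = arr1.view([('', arr1.dtype)] * arr1.shape[1])
arr2_view = arr2.view([('', arr2.dtype)] * arr2.shape[1])
# Returns (common rows, indices into arr1, indices into arr2)
intersected = np.intersect1d(arr1_view, arr2_view, return_indices=True)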
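
For reference, a self-contained version of the same idea on the example above might look like the sketch below. The arrays `A` and `B` stand in for `edge_index_dense` and `edge_index2_dense`, and the sketch assumes both are C-contiguous `int32` arrays with the same number of columns.

```python
import numpy as np

A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [3, 3, 3]], dtype=np.int32)
B = np.array([[1, 2, 3], [7, 8, 9], [4, 4, 4]], dtype=np.int32)

# Reinterpret each row as one structured element so whole rows compare as units.
row_dtype = [('', A.dtype)] * A.shape[1]
A_view = A.view(row_dtype)
B_view = B.view(row_dtype)

# intersect1d returns (common rows, indices into A, indices into B).
common, idx_A, idx_B = np.intersect1d(A_view, B_view, return_indices=True)
print(idx_A)  # indices of the rows of A that also appear in B
```

What I am looking for is the PyTorch equivalent of `idx_A`.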
DsCpp

2 Answers

5

This answer was posted before the OP updated the question with other restrictions that changed the problem quite a bit.

TL;DR You can do something like this:

torch.where((A == B).all(dim=1))[0]

First, assuming you have:

import torch
A = torch.Tensor([[1,2,3],[4,5,6],[7,8,9]])
B = torch.Tensor([[1,2,3],[4,4,4],[7,8,9]])

We can check that A == B returns:

>>> A == B
tensor([[ True,  True,  True],
        [ True, False, False],
        [ True,  True,  True]])

So what we want are the rows in which every element is True. For that, we can use the .all() operation and specify the dimension of interest, in our case 1:

>>> (A == B).all(dim=1)
tensor([ True, False,  True])

What you actually want to know is where the True entries are. For that, we can take the first element of the tuple returned by torch.where():

>>> torch.where((A == B).all(dim=1))[0]
tensor([0, 2])
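
If the two tensors differ in length or row order, the element-wise comparison above no longer lines up. One way to adapt it is to compare every row of A against every row of B via broadcasting. This is only a sketch: it assumes an intermediate of shape `(len(A), len(B), num_cols)` fits in memory, and the helper name `rows_of_a_in_b` is made up for this illustration.

```python
import torch

def rows_of_a_in_b(A, B):
    """Return indices of rows of A that also occur somewhere in B (hypothetical helper)."""
    # (len(A), 1, C) == (1, len(B), C) broadcasts to (len(A), len(B), C)
    eq = A.unsqueeze(1) == B.unsqueeze(0)
    # A pair of rows matches only if every column is equal -> (len(A), len(B))
    row_match = eq.all(dim=2)
    # A row of A counts if it matches at least one row of B
    return torch.where(row_match.any(dim=1))[0]

A = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [3, 3, 3]])
B = torch.tensor([[1, 2, 3], [7, 8, 9], [4, 4, 4]])
print(rows_of_a_in_b(A, B))  # tensor([0, 2])
```

Because the match is reduced with `.any(dim=1)`, the order of the rows in B does not matter and duplicates in B do not produce duplicate indices.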
Berriel
  • Thank you very much, I meant they may have a different length, and the order may be different. – DsCpp Jan 13 '20 at 06:00
  • @DsCpp Oh, I see. Then it is a completely different problem. Post the NumPy solutions that you have found. It will be easier to understand all your restrictions. – Berriel Jan 13 '20 at 11:15
2

If A and B are 2D tensors, the following code finds the indices such that A[indices] == B. If multiple indices satisfy this condition, the first index found is returned. If some rows of B are not present in A, those rows are ignored.

values, indices = torch.topk(((A.t() == B.unsqueeze(-1)).all(dim=1)).int(), 1, 1)
indices = indices[values!=0]
# indices = tensor([0, 2])
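
To see why this works, the one-liner can be unpacked into its intermediate shapes. The sketch below uses the tensors from the question; the variable names `eq` and `match` are only illustrative.

```python
import torch

A = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [3, 3, 3]])
B = torch.tensor([[1, 2, 3], [7, 8, 9], [4, 4, 4]])

# A.t() has shape (C, len(A)); B.unsqueeze(-1) has shape (len(B), C, 1).
# Broadcasting gives (len(B), C, len(A)): every row of B against every row of A.
eq = A.t() == B.unsqueeze(-1)

# Collapse the column axis: match[i, j] is 1 when B[i] equals A[j], else 0.
match = eq.all(dim=1).int()

# For each row of B, take the position in A with the largest value (1 if any match).
values, indices = torch.topk(match, 1, 1)

# Rows of B with no match in A have value 0 and are dropped.
indices = indices[values != 0]
print(indices)  # tensor([0, 2])
```

Note that the result has one entry per matched row of B, in the order of B, so it can differ from the sorted indices of A when B is ordered differently.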
okh