
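I have two tensors and would like to check, row by row, whether each element of a row in a appears in the same row of b. For example, given a and b below, the expected result is c (T = True, F = False):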

a = [[1,2,3], [7,8,4]]
b = [[2,1,1], [4,5,6]]
c = [[T,T,F], [F,F,T]]

I would like this to be done in pure PyTorch, as fast as possible.

jan biel

1 Answer


I found the solution here: https://stackoverflow.com/a/67870684/12216433

Applied to my case, it works like this:

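import torch
a, b = torch.tensor([[1, 2, 3], [7, 8, 4]]), torch.tensor([[2, 1, 1], [4, 5, 6]])
AA = a.reshape(2, 3, 1)  # shape (2, 3, 1): each element of a in its own slot
BB = b.reshape(2, 1, 3)  # shape (2, 1, 3): each row of b, broadcast against those slots
# (AA == BB) has shape (2, 3, 3); summing over the last axis counts matches in b's row
mask = (AA == BB).sum(-1).bool()  # True where a[i, j] occurs anywhere in b[i]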
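A minimal sketch that generalizes the answer so it does not hard-code the shapes (2, 3). The function names rowwise_isin_broadcast and rowwise_isin_offset are hypothetical. The first uses the same broadcasting idea as above, with unsqueeze and any(-1) in place of reshape and sum(-1).bool(). The second is a swapped-in alternative, not from the linked answer: it assumes integer tensors, shifts each row's values into a disjoint range, and then calls torch.isin (available in PyTorch 1.10+). It also assumes the shifted keys fit in int64.

```python
import torch


def rowwise_isin_broadcast(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """mask[i, j] is True if a[i, j] occurs anywhere in b[i].

    a has shape (R, N) and b has shape (R, M); N and M may differ.
    Builds an (R, N, M) boolean comparison tensor internally.
    """
    return (a.unsqueeze(-1) == b.unsqueeze(-2)).any(dim=-1)


def rowwise_isin_offset(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Same result for integer tensors, via torch.isin on row-offset keys.

    Assumption: a and b hold integers, and (max - min + 1) * R fits in int64.
    """
    lo = torch.minimum(a.min(), b.min())
    hi = torch.maximum(a.max(), b.max())
    span = hi - lo + 1  # size of the value range shared by a and b
    # Row i is shifted into [i * span, (i + 1) * span), so keys from
    # different rows can never collide.
    rows = torch.arange(a.shape[0], device=a.device).unsqueeze(1)
    a_key = (a - lo) + rows * span
    b_key = (b - lo) + rows * span
    return torch.isin(a_key, b_key)


a = torch.tensor([[1, 2, 3], [7, 8, 4]])
b = torch.tensor([[2, 1, 1], [4, 5, 6]])

print(rowwise_isin_broadcast(a, b).tolist())  # [[True, True, False], [False, False, True]]
print(rowwise_isin_offset(a, b).tolist())     # [[True, True, False], [False, False, True]]
```

The broadcasting version is the simplest and works for any dtype that supports ==. However, it materializes an R×N×M boolean tensor, which can become large for long rows. The offset version avoids building that tensor in your own code. Which one is faster depends on the tensor sizes and the device, so it is worth benchmarking both on real data.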