Unless I've messed something up, an element-wise 'in' check that treats each row (or sub-tensor) as a single element can be done like this:
```python
# For each row of a: True if it equals some row of b
(b[:, None] == a).all(dim=-1).any(dim=0)
```
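Here is a minimal runnable sketch of the one-liner. The answer doesn't give the actual values of `a` and `b`, so the ones below are hypothetical. They are chosen so that `a` has 3 rows and `b` has 2, and the result matches the `tensor([False, True, False])` quoted in this answer.

```python
import torch

# Hypothetical data: a holds 3 candidate rows, b holds 2 rows to look for.
a = torch.tensor([[1, 2], [3, 4], [5, 6]])
b = torch.tensor([[3, 4], [5, 8]])

# For each row of a: True if it equals some row of b
a_in_b = (b[:, None] == a).all(dim=-1).any(dim=0)
print(a_in_b.tolist())  # [False, True, False]
```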
`b[:,None]` adds a dimension to each "row" of `b` so that it can be broadcast against every "row" of `a` and compared element by element in the usual way. Since `b` has 2 rows, this produces 2 sub-tensors along the 0th dimension, each the same size as `a`. The first sub-tensor compares `a[0,0]`, `a[1,0]`, and `a[2,0]` with `b[0,0]`, and `a[0,1]`, `a[1,1]`, and `a[2,1]` with `b[0,1]`. The second sub-tensor similarly compares `a` with `b[1,0]` and `b[1,1]`.
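To see the broadcast concretely, the sketch below uses hypothetical values for `a` (3 rows, 2 columns) and `b` (2 rows, 2 columns). It prints the shapes and the two sub-tensors along dimension 0. `b[1] = [5, 8]` is deliberately a partial match for `a[2] = [5, 6]`, so its sub-tensor contains a lone `True`.

```python
import torch

# Hypothetical data: a is 3x2, b is 2x2
a = torch.tensor([[1, 2], [3, 4], [5, 6]])
b = torch.tensor([[3, 4], [5, 8]])

expanded = b[:, None]   # shape (2, 1, 2): one extra axis per row of b
cmp = expanded == a     # broadcasts to (2, 3, 2); cmp[i, j, k] is b[i, k] == a[j, k]
print(tuple(expanded.shape), tuple(cmp.shape))  # (2, 1, 2) (2, 3, 2)

print(cmp[0].tolist())  # every row of a vs b[0]: [[False, False], [True, True], [False, False]]
print(cmp[1].tolist())  # every row of a vs b[1]: [[False, False], [False, False], [True, False]]
```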
So, in the last dimension, any sub-tensor that is all `True` marks a row of `a` that fully matched `b[0]` or `b[1]`. Applying `.all(dim=-1)` reduces each comparison to one boolean per row. The first element along the 0th dimension then tells which rows of `a` equal `b[0]`, and the second tells which rows of `a` equal `b[1]`.
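Here is a small sketch of the `.all(dim=-1)` step, again with hypothetical `a` and `b`. The partial match between `b[1]` and `a[2]` is discarded, because only one of the two columns agrees.

```python
import torch

# Hypothetical data: b[1] = [5, 8] matches a[2] = [5, 6] only in the first column
a = torch.tensor([[1, 2], [3, 4], [5, 6]])
b = torch.tensor([[3, 4], [5, 8]])

row_match = (b[:, None] == a).all(dim=-1)  # shape (2, 3)
# row_match[i, j] is True only when the whole row a[j] equals b[i]
print(row_match.tolist())  # [[False, True, False], [False, False, False]]
```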
Then, to get `a in b`, simply apply `.any(dim=0)` to combine the two results. Here this gives `tensor([False, True, False])`: only `a[1]` appears as a row of `b`.
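Putting the steps together, here is a sketch of a reusable version. The helper name `isin_rows` is hypothetical, and it assumes `a` and `b` are at least 2-D with the same trailing shape. Flattening each sub-tensor with `flatten(start_dim=1)` lets the same `.all(dim=-1).any(dim=0)` pattern treat sub-tensors of any shape as single elements. The result is cross-checked against Python's own `in` on nested lists.

```python
import torch

def isin_rows(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: result[j] is True if sub-tensor a[j] equals some b[i].

    Assumes a and b are at least 2-D and a.shape[1:] == b.shape[1:].
    """
    a2 = a.flatten(start_dim=1)  # each sub-tensor of a becomes one row
    b2 = b.flatten(start_dim=1)  # same for b
    row_match = (b2[:, None] == a2).all(dim=-1)  # (len(b), len(a))
    return row_match.any(dim=0)                  # (len(a),)

# Hypothetical 2-D data
a = torch.tensor([[1, 2], [3, 4], [5, 6]])
b = torch.tensor([[3, 4], [5, 8]])
result = isin_rows(a, b)

# Cross-check against Python's own `in` on nested lists
expected = [row in b.tolist() for row in a.tolist()]
print(result.tolist(), expected)  # [False, True, False] [False, True, False]

# Sub-tensors that are themselves 2x2 matrices
a3 = torch.arange(12).reshape(3, 2, 2)
b3 = a3[2:].clone()
print(isin_rows(a3, b3).tolist())  # [False, False, True]
```

The broadcast materialises a boolean tensor of size `len(b) * len(a) * row_size`, so with very large inputs memory may become the limiting factor.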