My question is nearly identical to this one, with the notable difference that I am working in PyTorch. I would prefer not to use the NumPy solution, since it would involve moving data back to the CPU. I see that PyTorch, like NumPy, has a nonzero function; however, its where function (the solution in the NumPy thread I linked) behaves differently from NumPy's.
The behavior I want is an is_zero() function that returns the indices of the zero elements, mirroring nonzero(), as follows:
>>> arr.nonzero()
tensor([[0, 1],
        [1, 0]])
>>> arr.is_zero()
tensor([[0, 0],
        [1, 1]])
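A minimal sketch of one way to get this without leaving the GPU, assuming "is_zero" means "indices of the elements equal to zero": compare the tensor against 0 to get a boolean mask, then call nonzero() on that mask. Here is_zero is a hypothetical helper named after the question, not a PyTorch API, and the example tensor values are assumed, chosen only so that arr.nonzero() reproduces the output shown above.

```python
import torch


def is_zero(arr: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: indices of zero-valued elements, one row each.

    Both the comparison and nonzero() run on arr's own device,
    so no copy to the CPU is made.
    """
    # Boolean mask that is True wherever arr equals zero.
    mask = arr == 0
    # nonzero() on the mask returns an (N, arr.dim()) int64 tensor of indices.
    return torch.nonzero(mask)


# Use the GPU if one is available, to show the result stays on that device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Assumed example values: zeros on the diagonal, nonzeros off it.
arr = torch.tensor([[0, 5], [7, 0]], device=device)

print(arr.nonzero())  # rows [0, 1] and [1, 0]
print(is_zero(arr))   # rows [0, 0] and [1, 1]
```

The same thing can be written inline as (arr == 0).nonzero(). If the NumPy-style layout (a tuple of 1-D index tensors, one per dimension) is preferred, torch.nonzero(mask, as_tuple=True) returns that instead. For floating-point tensors, note that arr == 0 is an exact comparison.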