
Very simple question, but I have been struggling with this for a long time now.

import torch
t = torch.tensor([[2,3],[4,6]])
overlap = [2, 6]
f = lambda x: x in overlap

Applying f element-wise to t, I want:

torch.tensor([[True,False],[False,True]])

Both the tensor and overlap are very large, so efficiency matters here.
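To pin down the desired semantics, here is a slow but unambiguous pure-Python reference that any faster method should match. The helper name `membership_mask_reference` is hypothetical, and it is only a sketch for checking results, not a solution for large inputs:

```python
import torch

def membership_mask_reference(t: torch.Tensor, overlap) -> torch.Tensor:
    """Hypothetical reference helper: True where an element of t is in overlap."""
    lookup = set(overlap)  # average O(1) membership test per element
    # Walk every element in Python; correct but slow for big tensors.
    flat = [x in lookup for x in t.flatten().tolist()]
    return torch.tensor(flat, dtype=torch.bool).reshape(t.shape)

t = torch.tensor([[2, 3], [4, 6]])
overlap = [2, 6]
print(membership_mask_reference(t, overlap))
```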

Marcel Braasch

2 Answers


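The native way to do this is the torch.Tensor.apply_ method. Note that apply_ works in place and stores f's result in t's own dtype, so clone t first and convert the 1/0 result to bool:

mask = t.clone().apply_(f).bool()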

However, according to the official documentation, it only works on CPU tensors and should not be used in code that requires high performance, since f is called from Python once per element.
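A minimal sketch wrapping apply_ in a helper that also copes with non-CPU tensors and leaves the input untouched. The name `membership_mask_apply` is hypothetical, and converting overlap to a set is an assumption made for speed with a long overlap list:

```python
import torch

def membership_mask_apply(t: torch.Tensor, overlap) -> torch.Tensor:
    """Hypothetical helper: element-wise `x in overlap` via Tensor.apply_."""
    lookup = set(overlap)  # set membership avoids scanning a long list per element
    # apply_ only works on CPU tensors and modifies its tensor in place,
    # so run it on a CPU copy and leave the caller's tensor untouched.
    work = t.detach().cpu().clone()
    work.apply_(lambda x: x in lookup)
    # The bool result is stored in the tensor's own dtype (1/0 for integer
    # tensors), so convert to bool and move back to the original device.
    return work.bool().to(t.device)

t = torch.tensor([[2, 3], [4, 6]])
mask = membership_mask_apply(t, [2, 6])
print(mask)
```

This is still a per-element Python call, so it inherits the performance caveat above.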

Besides, there does not seem to be a native torch function that checks whether tensor values are in a list, so the only option seems to be iterating over overlap. See here and here. Thus you can try:

sum(t==i for i in overlap).bool()

I found that the second approach is faster for large t and overlap, and the first one for small t and overlap.
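The loop over overlap can also be replaced by a single broadcast comparison. This is a sketch of that swapped-in technique, not something the answer above benchmarked; it assumes numel(t) × len(overlap) booleans fit in memory, which may not hold when both are very large:

```python
import torch

t = torch.tensor([[2, 3], [4, 6]])
overlap = [2, 6]

values = torch.as_tensor(overlap, dtype=t.dtype, device=t.device)
# t.unsqueeze(-1) has shape (2, 2, 1); values broadcasts along the new last
# axis, giving a (2, 2, len(overlap)) comparison that .any() collapses to (2, 2).
mask = (t.unsqueeze(-1) == values).any(dim=-1)
print(mask)
```

Unlike sum(...).bool(), this also returns an all-False mask when overlap is empty.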

Valentin Goldité
I figured that there is no function for this; I am wondering why that is. For example, we do have elementary operations such as `add` that are applied element-wise, so I don't understand why no general function exists. Maybe I will look at how these are implemented in the code. – Marcel Braasch Sep 08 '22 at 22:46
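For this specific membership test, PyTorch 1.10 and later do ship a native function, torch.isin. A minimal sketch, assuming a PyTorch version that includes it; it is an ordinary tensor operation, so it also runs on GPU tensors without a Python-level loop:

```python
import torch

t = torch.tensor([[2, 3], [4, 6]])
overlap = torch.tensor([2, 6])

# torch.isin(elements, test_elements) returns a bool tensor shaped like `elements`.
mask = torch.isin(t, overlap)
print(mask)
```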

I found an easy way. Since a CPU torch tensor can be converted to a NumPy array, the following works (note that np.vectorize is essentially a Python for loop, so it is convenient rather than fast):

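import torch
import numpy as np
t = torch.tensor([[2,3],[4,6]])
overlap = {2, 6}  # a set makes each membership test O(1) on average
f = lambda x: x in overlap
# np.vectorize returns a NumPy bool array; convert it back to a torch tensor
mask = torch.from_numpy(np.vectorize(f, otypes=[bool])(t.cpu().numpy()))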

Found here.
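If the NumPy route is acceptable, np.isin does the same membership test on whole arrays instead of calling a Python function once per element. A minimal sketch, assuming t lives on the CPU (or is moved there first) and does not require grad:

```python
import numpy as np
import torch

t = torch.tensor([[2, 3], [4, 6]])
overlap = [2, 6]

# .numpy() requires a CPU tensor; torch.from_numpy wraps the result without copying.
mask = torch.from_numpy(np.isin(t.numpy(), overlap))
print(mask)
```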

Marcel Braasch