
I have some PyTorch tensors (or NumPy arrays) and want to set the two highest values to 1 and every other value to 0. For example, this tensor

tensor([0.9998, 0.9997, 0.9991, 0.9998, 0.9996, 0.9996, 0.9997, 0.9995],
   dtype=torch.float64)

should become this:

tensor([1., 0., 0., 1., 0., 0., 0., 0.],
   dtype=torch.float64)

There are ways to set the single highest value to 1 and the rest to 0, but I need the two highest values to become 1. Why isn't there a built-in function for that? I have a classification problem where I know that two objects belong to class 1, so shouldn't this be a fairly common problem?

Asked by thyhmoo, edited by iacob

1 Answer


You can do this with topk, which returns the k largest values along with their indices:

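```python
import torch

x = torch.tensor([0.9998, 0.9997, 0.9991, 0.9998, 0.9996, 0.9996, 0.9997, 0.9995], dtype=torch.float64)
_, idx = x.topk(2)  # indices of the two largest values
x.fill_(0)          # note: this overwrites x in place
x[idx] = 1
```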
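If you want to keep the original tensor, work on batches (one row per sample), or do the same with NumPy arrays, a minimal sketch might look like the following. The helper names top_k_mask_torch and top_k_mask_numpy are hypothetical, not library functions; the PyTorch version combines topk with Tensor.scatter_ on a fresh zeros_like tensor, and the NumPy version swaps in np.argpartition plus np.put_along_axis. It assumes the values to rank lie along the last dimension. With ties at the cutoff, which of the tied elements gets the 1 is not something you should rely on.

```python
import numpy as np
import torch


def top_k_mask_torch(x: torch.Tensor, k: int = 2) -> torch.Tensor:
    """Hypothetical helper: 1 at the k largest entries along the last dim, 0 elsewhere."""
    _, idx = x.topk(k, dim=-1)          # idx has shape (..., k)
    mask = torch.zeros_like(x)          # same shape and dtype as x; x is left untouched
    mask.scatter_(-1, idx, 1.0)         # write 1 at each index in idx along the last dim
    return mask


def top_k_mask_numpy(a: np.ndarray, k: int = 2) -> np.ndarray:
    """Hypothetical NumPy counterpart using argpartition instead of topk."""
    # argpartition moves the k largest entries into the last k slots (in no particular order)
    idx = np.argpartition(a, -k, axis=-1)[..., -k:]
    mask = np.zeros_like(a)
    np.put_along_axis(mask, idx, 1, axis=-1)
    return mask


x = torch.tensor([0.9998, 0.9997, 0.9991, 0.9998, 0.9996, 0.9996, 0.9997, 0.9995],
                 dtype=torch.float64)
print(top_k_mask_torch(x))          # tensor([1., 0., 0., 1., 0., 0., 0., 0.], dtype=torch.float64)
print(top_k_mask_numpy(x.numpy()))  # [1. 0. 0. 1. 0. 0. 0. 0.]
```

Because both helpers operate along the last dimension, a 2-D input of shape (batch, n) gets exactly k ones per row.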
Answered by iacob