I have an index tensor of size (2, 3):

>>> index = torch.empty(6).random_(0,8).view(2,3)
>>> index
tensor([[6., 3., 2.],
        [3., 4., 7.]])

And a value tensor of size (2, 8):

>>> value = torch.zeros(2,8)
>>> value
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

I want to set the elements of value to 1 at the positions given by index along dim=-1. The output should look like this:

>>> output
tensor([[0., 0., 1., 1., 0., 0., 1., 0.],
        [0., 0., 0., 1., 1., 0., 0., 1.]])

I tried value[range(2), index] = 1, but it raises an error. I also tried torch.index_fill, but it doesn't accept batched indices. torch.scatter requires creating an extra (2, 8) tensor full of ones as src, which costs unnecessary memory and time.
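For comparison, the advanced-indexing attempt can likely be made to work. A minimal sketch, assuming the example index values shown above: range(2) yields row indices of shape (2,), which do not broadcast against the (2, 3) index, and a float tensor is not accepted as an index. Reshaping the row indices to (2, 1) and converting index to int64 addresses both.

```python
import torch

# Example index from the question (float, as produced by random_())
index = torch.tensor([[6., 3., 2.],
                      [3., 4., 7.]])
value = torch.zeros(2, 8)

# Row indices of shape (2, 1) broadcast against the (2, 3) column indices,
# so row i is paired with every entry of index[i]
rows = torch.arange(value.size(0)).unsqueeze(1)

# Index tensors must be integer (here int64) rather than float
value[rows, index.long()] = 1
print(value)
```

Unlike scatter with a src tensor, this writes the scalar 1 directly, so no extra tensor of ones is allocated.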


1 Answer


You can actually use torch.Tensor.scatter_ by passing the scalar value argument instead of the tensor src argument.

>>> value.scatter_(dim=-1, index=index.long(), value=1)

>>> value
tensor([[0., 0., 1., 1., 0., 0., 1., 0.],
        [0., 0., 0., 1., 1., 0., 0., 1.]])

Make sure index is of type int64 (torch.long), though; that is why .long() is called above.
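A self-contained sketch of the dtype requirement, assuming the example indices from the question: a float index tensor such as the one produced by random_() is rejected by scatter_, while the same values converted with .long() work. It also shows the out-of-place Tensor.scatter, which accepts the same value argument but returns a new tensor instead of modifying its input. The exact error message varies between PyTorch versions.

```python
import torch

expected = [[0., 0., 1., 1., 0., 0., 1., 0.],
            [0., 0., 0., 1., 1., 0., 0., 1.]]

# Same indices as in the question; random_() leaves them as floats
float_index = torch.tensor([[6., 3., 2.],
                            [3., 4., 7.]])

# A floating-point index tensor is rejected by scatter_
try:
    torch.zeros(2, 8).scatter_(dim=-1, index=float_index, value=1)
    float_rejected = False
except RuntimeError as err:
    float_rejected = True
    print("float index rejected:", err)

# Converting to int64 (torch.long) makes the call valid
long_index = float_index.long()
value = torch.zeros(2, 8)
value.scatter_(dim=-1, index=long_index, value=1)  # modifies value in place
print(value)

# Out-of-place variant: returns a new tensor, base stays all zeros
base = torch.zeros(2, 8)
out = base.scatter(dim=-1, index=long_index, value=1)
print(out)
```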