
I would like to do something like an argmax, but returning the indices of several top values instead of just one. I know how to use the normal torch.argmax:

>>> a = torch.randn(4, 4)
>>> a
tensor([[ 1.3398,  1.2663, -0.2686,  0.2450],
        [-0.7401, -0.8805, -0.3402, -1.1936],
        [ 0.4907, -1.3948, -1.0691, -0.3132],
        [-1.6092,  0.5419, -0.2993,  0.3195]])
>>> torch.argmax(a)
tensor(0)

But now I need the indices of the top N values, so something like this:

>>> a = torch.randn(4, 4)
>>> a
tensor([[ 1.3398,  1.2663, -0.2686,  0.2450],
        [-0.7401, -0.8805, -0.3402, -1.1936],
        [ 0.4907, -1.3948, -1.0691, -0.3132],
        [-1.6092,  0.5419, -0.2993,  0.3195]])
>>> torch.argmax(a, top_n=2)
tensor([0, 1])

I haven't found any function in PyTorch that does this. Does anyone know of one?

eljiwo

2 Answers


Great! So you need the k largest elements of a tensor. Depending on what you mean, there are two answers.

[Answer 1] You need the k largest of all the elements, regardless of dimension. In that case, flatten the tensor and use torch.topk to get the indices of the top 3 (for example) elements:

>>> a = torch.randn(5,4)
>>> a
tensor([[ 0.8292, -0.5123, -0.0741, -0.3043],
        [-0.4340, -0.7763,  1.9716, -0.5620],
        [ 0.1582, -1.2000,  1.0202, -1.5202],
        [-0.3617, -0.2479,  0.6204,  0.2575],
        [ 1.8025,  1.9864, -0.8013, -0.7508]])
>>> torch.topk(a.flatten(), 3).indices
tensor([17,  6, 16])
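If you need 2-D positions rather than flat indices, the flat indices can be converted back to (row, column) pairs with integer division and modulo, because flatten() lays the elements out in row-major order. A minimal sketch, hard-coding the values printed above so the result is reproducible:

```python
import torch

# Same 5x4 values as in the example above, hard-coded for reproducibility
a = torch.tensor([[ 0.8292, -0.5123, -0.0741, -0.3043],
                  [-0.4340, -0.7763,  1.9716, -0.5620],
                  [ 0.1582, -1.2000,  1.0202, -1.5202],
                  [-0.3617, -0.2479,  0.6204,  0.2575],
                  [ 1.8025,  1.9864, -0.8013, -0.7508]])

k = 3
flat_idx = torch.topk(a.flatten(), k).indices  # indices into the flattened tensor

# Row-major layout: flat index = row * ncols + col
ncols = a.shape[1]
rows = torch.div(flat_idx, ncols, rounding_mode="floor")
cols = flat_idx % ncols
coords = torch.stack((rows, cols), dim=1)

print(coords.tolist())  # → [[4, 1], [1, 2], [4, 0]]
```

Indexing with `a[rows, cols]` then returns the same values as `torch.topk(a.flatten(), k).values`.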

[Answer 2] You need the k largest elements of the input tensor along a given dimension. For this, pass the dim argument to torch.topk; see the PyTorch documentation for torch.topk.
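For instance, a minimal sketch on the tensor from the question: dim selects the dimension along which the top k are taken, so every row (dim=1) or every column (dim=0) gets its own k indices, largest first.

```python
import torch

# The 4x4 tensor from the question, hard-coded for reproducibility
a = torch.tensor([[ 1.3398,  1.2663, -0.2686,  0.2450],
                  [-0.7401, -0.8805, -0.3402, -1.1936],
                  [ 0.4907, -1.3948, -1.0691, -0.3132],
                  [-1.6092,  0.5419, -0.2993,  0.3195]])

# Top 2 within each row: result has shape (4, 2)
row_top = torch.topk(a, k=2, dim=1)
print(row_top.indices.tolist())  # → [[0, 1], [2, 0], [0, 3], [1, 3]]

# Top 2 within each column: result has shape (2, 4)
col_top = torch.topk(a, k=2, dim=0)
print(col_top.indices.tolist())  # → [[0, 0, 0, 3], [2, 3, 3, 0]]
```

Passing `largest=False` returns the k smallest elements instead.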


I believe you are looking for torch.topk.
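To reproduce the output asked for in the question, apply torch.topk to the flattened tensor. The argmax_top_n helper below is a hypothetical name for illustration, not a PyTorch function; it is a minimal sketch that assumes you want flat indices, largest value first:

```python
import torch

# The 4x4 tensor from the question, hard-coded for reproducibility
a = torch.tensor([[ 1.3398,  1.2663, -0.2686,  0.2450],
                  [-0.7401, -0.8805, -0.3402, -1.1936],
                  [ 0.4907, -1.3948, -1.0691, -0.3132],
                  [-1.6092,  0.5419, -0.2993,  0.3195]])

def argmax_top_n(t: torch.Tensor, n: int) -> torch.Tensor:
    """Hypothetical helper: flat indices of the n largest values, largest first."""
    return torch.topk(t.flatten(), n).indices

print(argmax_top_n(a, 2).tolist())  # → [0, 1]

# With n=1 this matches torch.argmax, which also indexes the flattened tensor
assert argmax_top_n(a, 1).item() == torch.argmax(a).item()
```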

trialNerror