The existing answers are correct, but I wanted to expand on them to provide a self-contained function that behaves exactly like `torch.topk` with pure numpy.

Here's the function (I've included explanatory comments inline):
import numpy as np

def topk(array, k, axis=-1, sorted=True):
    # np.argpartition is faster than np.argsort, but it does not return the values in order
    # We use array.take because you can specify the axis
    partitioned_ind = (
        np.argpartition(array, -k, axis=axis)
        .take(indices=range(-k, 0), axis=axis)
    )
    # We use the newly selected indices to find the scores of the top-k values
    partitioned_scores = np.take_along_axis(array, partitioned_ind, axis=axis)
    if sorted:
        # Since our top-k indices are not correctly ordered, we can sort them with argsort,
        # but only if sorted=True (otherwise we keep them in an arbitrary order)
        sorted_trunc_ind = np.flip(
            np.argsort(partitioned_scores, axis=axis), axis=axis
        )
        # We again use np.take_along_axis as we have an array of indices that we use to
        # decide which values to select
        ind = np.take_along_axis(partitioned_ind, sorted_trunc_ind, axis=axis)
        scores = np.take_along_axis(partitioned_scores, sorted_trunc_ind, axis=axis)
    else:
        ind = partitioned_ind
        scores = partitioned_scores
    return scores, ind
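For a quick illustration before the torch comparison, here is a small made-up example (the input values are arbitrary and only serve this sketch), taking the top-2 entries of each row:

import numpy as np

x = np.array([
    [1.0, 7.0, 3.0, 5.0],
    [9.0, 2.0, 8.0, 4.0],
    [0.0, 6.0, 6.5, 1.0],
])

scores, ind = topk(x, k=2, axis=-1)
# scores -> [[7. , 5. ], [9. , 8. ], [6.5, 6. ]]   (largest first along the last axis)
# ind    -> [[1, 3], [0, 2], [2, 1]]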
To verify the correctness, you can test it against torch:
import torch
import numpy as np
x = np.random.randn(50, 50, 10, 10)
axis = 2 # Change this to any axis and it'll be fine
val_np, ind_np = topk(x, k=10, axis=axis)
val_pt, ind_pt = torch.topk(torch.tensor(x), k=10, dim=axis)
print("Values are same:", np.all(val_np == val_pt.numpy()))
print("Indices are same:", np.all(ind_np == ind_pt.numpy()))
- To be clear, `np.take_along_axis` is the recommended companion to `np.argpartition`: it uses the returned indices to access the original values along an axis of a higher-dimensional array.
- `np.argpartition` is faster than `np.argsort` because it does not sort the entire array. This answer claims it takes O(n) instead of `O(n log n)`; a rough timing sketch is shown below.
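A minimal timing sketch of that gap (absolute numbers depend entirely on your machine, the array size, and k; the point is only that the partition-based approach avoids a full sort):

import timeit

import numpy as np

x = np.random.randn(1_000_000)
k = 10

# Full sort of the whole array (argsort), then slice the last k indices
t_sort = timeit.timeit(lambda: np.argsort(x)[-k:], number=20)

# Partial partition (introselect), which only guarantees that the last k positions
# hold the indices of the k largest elements, in arbitrary order
t_part = timeit.timeit(lambda: np.argpartition(x, -k)[-k:], number=20)

print(f"argsort:      {t_sort:.3f} s")
print(f"argpartition: {t_part:.3f} s")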