Given the task of indirectly sorting a subset (the top k, top meaning first in sort order), there are two builtin solutions: argsort and argpartition; cf. @Divakar's answer.
If, however, performance is a consideration, then it may (depending on the sizes of the data and the subset of interest) be well worth resisting the "lure of the one-liner", investing one more line, and applying argsort to the output of argpartition:
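>>> import timeit
>>> import numpy as np
>>> def top_k_sort(a, k):
...     # full indirect sort, then keep the first k indices
...     return np.argsort(a)[:k]
...
>>> def top_k_argp(a, k):
...     # partition on every position 0..k-1, which leaves the first k sorted
...     return np.argpartition(a, range(k))[:k]
...
>>> def top_k_hybrid(a, k):
...     # select the k smallest (unordered), then sort only those k
...     b = np.argpartition(a, k)[:k]
...     return b[np.argsort(a[b])]
...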
>>> k = 100
>>> timeit.timeit('f(a,k)', 'a=rng((100000,))', number = 1000, globals={'f': top_k_sort, 'rng': np.random.random, 'k': k})
8.348663672804832
>>> timeit.timeit('f(a,k)', 'a=rng((100000,))', number = 1000, globals={'f': top_k_argp, 'rng': np.random.random, 'k': k})
9.869098862167448
>>> timeit.timeit('f(a,k)', 'a=rng((100000,))', number = 1000, globals={'f': top_k_hybrid, 'rng': np.random.random, 'k': k})
1.2305558240041137
argsort is O(n log n), argpartition with a range argument appears to be O(nk) (?), and argpartition + argsort is O(n + k log k).
Therefore, in the interesting regime n >> k >> 1, the hybrid method is expected to be fastest.
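Before trusting the timings, it may be worth confirming that the hybrid approach returns the same indices as the plain argsort. The following is a minimal sketch of such a check, not part of the original benchmark. It assumes distinct values (a shuffled range), so that the top-k order is unique; with ties, the two methods could legitimately return different but equally valid indices. The seed and the variable names `idx_sort`, `idx_hybrid`, `same_indices` and `top_values` are illustrative only.

```python
import numpy as np

def top_k_sort(a, k):
    # Full indirect sort, then keep the first k indices.
    return np.argsort(a)[:k]

def top_k_hybrid(a, k):
    # Select the k smallest (in arbitrary order), then sort only those k.
    b = np.argpartition(a, k)[:k]
    return b[np.argsort(a[b])]

# Assumption: all values are distinct, so the top-k indices are unique.
rng = np.random.default_rng(0)
a = rng.permutation(100_000).astype(float)
k = 100

idx_sort = top_k_sort(a, k)
idx_hybrid = top_k_hybrid(a, k)

# Both methods should agree index by index.
same_indices = np.array_equal(idx_sort, idx_hybrid)
# The selected values should be the k smallest, in ascending order.
top_values = a[idx_hybrid]
print(same_indices)
```

Since `a` is a permutation of 0..n-1 here, `top_values` should simply be 0, 1, ..., k-1, which makes the result easy to inspect by eye.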
UPDATE: ND version:
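import numpy as np
from timeit import timeit

def top_k_sort(A, k, axis=-1):
    # Full argsort along `axis`, then keep the first k entries along that axis.
    return A.argsort(axis=axis)[(*axis % A.ndim * (slice(None),), slice(k))]

def top_k_partition(A, k, axis=-1):
    # Partition on every position 0..k-1, which leaves the first k sorted.
    return A.argpartition(range(k), axis=axis)[(*axis % A.ndim * (slice(None),), slice(k))]

def top_k_hybrid(A, k, axis=-1):
    # Select the k smallest along `axis`, then sort only those k.
    B = A.argpartition(k, axis=axis)[(*axis % A.ndim * (slice(None),), slice(k))]
    return np.take_along_axis(B, np.take_along_axis(A, B, axis).argsort(axis), axis)

A = np.random.random((100, 10000))
k = 100

# Time every function whose name starts with "top_".
# The total for 10 runs, times 100, is the time per call in milliseconds.
for f in globals().copy():
    if f.startswith("top_"):
        print(f, timeit(f"{f}(A,k)", globals=globals(), number=10) * 100)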
Sample run (milliseconds per call):
top_k_sort 63.72379460372031
top_k_partition 99.30561298970133
top_k_hybrid 10.714635509066284
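The axis handling in the ND version is compact enough that a quick sanity check may help. The sketch below is an illustration rather than part of the original benchmark: it compares the values selected by the hybrid function against a full np.sort along several axes. The helper name `head`, the array shape, the seed and the `checks` list are assumptions made for this example. Comparing values rather than indices keeps the check valid even if ties occur.

```python
import numpy as np

def top_k_hybrid(A, k, axis=-1):
    # Index tuple: full slices for every dimension before `axis`,
    # then the first k entries along `axis`.
    head = (*axis % A.ndim * (slice(None),), slice(k))
    # Indices of the k smallest along `axis`, in arbitrary order.
    B = A.argpartition(k, axis=axis)[head]
    # Sort only those k candidates and reorder their indices accordingly.
    order = np.take_along_axis(A, B, axis).argsort(axis)
    return np.take_along_axis(B, order, axis)

rng = np.random.default_rng(0)
A = rng.random((50, 200))
k = 10

checks = []
for axis in (0, 1, -1):
    idx = top_k_hybrid(A, k, axis)
    got = np.take_along_axis(A, idx, axis)
    head = (*axis % A.ndim * (slice(None),), slice(k))
    want = np.sort(A, axis=axis)[head]
    checks.append(np.array_equal(got, want))

print(checks)
```

Note that argpartition requires k to be smaller than the length of the array along `axis`, so this sketch keeps k below both dimensions.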