You could sort the array, but as of NumPy 1.8, there is a faster way to find the N largest values (particularly when data is large):
Using numpy.argpartition:
import numpy as np
data = np.array([[[ 512,  520,    1, 130523]],
                 [[ 520,  614,  573,   7448]],
                 [[ 614,  616,  615,    210]],
                 [[ 616,  622,  619,    269]],
                 [[ 622,  624,  623,    162]],
                 [[ 625,  770,  706,   8822]],
                 [[ 770,  776,  773,    241]]])
idx = np.argpartition(-data[...,-1].flatten(), 3)
print(data[idx[:3]])
yields
[[[ 520 614 573 7448]]
[[ 512 520 1 130523]]
[[ 625 770 706 8822]]]
np.argpartition performs a partial sort. It returns the indices of the array arranged so that the element at the given kth position lands in its final sorted position: every element before it is smaller or equal, every element after it is larger or equal, and neither group is itself sorted (which is what saves time). Here the last column is negated and partitioned at position 3, so idx[:3] point at the three smallest negated values, i.e. the three largest original values.
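Here is a minimal 1-D sketch of that guarantee (the array values below are made up purely for illustration):
import numpy as np

a = np.array([9, 1, 7, 3, 8, 2, 5])
idx = np.argpartition(a, 3)

# a[idx[3]] is 5, the value a full sort would place at index 3;
# a[idx[:3]] contains {1, 2, 3} and a[idx[4:]] contains {7, 8, 9},
# but the order within each group is not guaranteed.
print(a[idx[3]], a[idx[:3]], a[idx[4:]])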
Notice that the 3 highest rows are not returned in the same order in which they appear in data, nor are they sorted by value.
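If you do need those rows ordered by value, it is cheap to fully sort just the selected rows afterwards; a small sketch reusing data and idx from above:
top3 = data[idx[:3]]
# Sort only the 3 selected rows by their last column, descending.
print(top3[np.argsort(-top3[..., -1].flatten())])
This keeps the extra cost at O(k log k) for the k selected rows instead of re-sorting the whole array.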
For comparison, here is how you could find the 3 highest rows by using np.argsort (which performs a full sort):
idx = np.argsort(data[..., -1].flatten())
print(data[idx[-3:]])
yields
[[[ 520 614 573 7448]]
[[ 625 770 706 8822]]
[[ 512 520 1 130523]]]
Note: np.argsort is faster for small arrays:
In [63]: %timeit idx = np.argsort(data[..., -1].flatten())
100000 loops, best of 3: 2.6 µs per loop
In [64]: %timeit idx = np.argpartition(-data[...,-1].flatten(), 3)
100000 loops, best of 3: 5.61 µs per loop
But np.argpartition is faster for large arrays, since it avoids a full O(n log n) sort:
In [92]: data2 = np.tile(data, (10**3,1,1))
In [93]: data2.shape
Out[93]: (7000, 1, 4)
In [94]: %timeit idx = np.argsort(data2[..., -1].flatten())
10000 loops, best of 3: 164 µs per loop
In [95]: %timeit idx = np.argpartition(-data2[...,-1].flatten(), 3)
10000 loops, best of 3: 49.5 µs per loop