I have the following data structure:

 [[[   512    520     1 130523]]

 [[   520    614    573   7448]]

 [[   614    616    615    210]]

 [[   616    622    619    269]]

 [[   622    624    623    162]]

 [[   625    770    706   8822]]

 [[   770    776    773    241]]]

I'm trying to return an object of the same shape that contains only the three rows with the largest values in the 4th column. In this case, that would be rows 1, 2 and 6.

What's the most elegant way to do this?

unutbu
cjm2671

3 Answers


You can use sorted() and specify that you want to sort by the 4th column:

l = [[[512,    520 ,    1, 130523]],
 [[   520 ,   614  ,  573,   7448]],
 [[   614 ,   616  ,  615,    210]],
 [[   616 ,   622  ,  619,    269]],
 [[   622 ,   624  ,  623,    162]],
 [[   625 ,   770  ,  706,   8822]],
 [[   770 ,   776  ,  773,    241]]]

top3 =  sorted(l, key=lambda x: x[0][3], reverse=True)[:3]

print(top3)

will give you:

[[[512, 520, 1, 130523]], [[625, 770, 706, 8822]], [[520, 614, 573, 7448]]]
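
If the rows should keep the order they had in the original list (rows 1, 2 and 6 in the question), one option is to pick the indices of the top three rows and then sort those indices. This is only a sketch; the names `rows`, `top_idx` and `top3_in_order` are illustrative and not part of the code above:

```python
rows = [[[512, 520, 1, 130523]],
        [[520, 614, 573, 7448]],
        [[614, 616, 615, 210]],
        [[616, 622, 619, 269]],
        [[622, 624, 623, 162]],
        [[625, 770, 706, 8822]],
        [[770, 776, 773, 241]]]

# Indices of the three rows with the largest 4th value, largest first.
top_idx = sorted(range(len(rows)), key=lambda i: rows[i][0][3], reverse=True)[:3]

# Sorting the indices themselves restores the original row order.
top3_in_order = [rows[i] for i in sorted(top_idx)]
print(top3_in_order)
```

The nesting of each row (`[[...]]`) is left untouched, so the result has the same structure as the input.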
jh314

You could sort the array, but as of NumPy 1.8, there is a faster way to find the N largest values (particularly when data is large):

Using numpy.argpartition:

import numpy as np
data = np.array([[[ 512,    520,     1, 130523]],
                 [[ 520,    614,    573,   7448]],
                 [[ 614,    616,    615,    210]],
                 [[ 616,    622,    619,    269]],
                 [[ 622,    624,    623,    162]],
                 [[ 625,    770,    706,   8822]],
                 [[ 770,    776,    773,    241]]])

idx = np.argpartition(-data[...,-1].flatten(), 3)
print(data[idx[:3]])

yields

[[[   520    614    573   7448]]

 [[   512    520      1 130523]]

 [[   625    770    706   8822]]]

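np.argpartition performs a partial sort. It returns indices that put the array in partially sorted order: the element at position k is in its final sorted position, every element before it is less than or equal to it, and every element after it is greater than or equal to it. The elements on either side of position k are not sorted among themselves, which is what saves time. Because the values are negated here, the first 3 indices point to the 3 largest values.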
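
To make the partition property concrete, here is a small sketch on a made-up 1-D array (`values` and `k` are illustrative names, not from the data above):

```python
import numpy as np

values = np.array([7, 2, 9, 4, 1, 8, 3])
k = 3

# Indices that partition `values` around position k.
idx = np.argpartition(values, k)

# The element at position k is where a full sort would put it.
print(values[idx[k]])            # 4

# Everything before position k is <= that element, in no particular order.
print(np.sort(values[idx[:k]]))  # [1 2 3]

# Everything after position k is >= that element, also unordered.
print(np.sort(values[idx[k + 1:]]))  # [7 8 9]
```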

Notice that the 3 highest rows are returned neither in the order in which they appeared in data nor sorted by value; np.argpartition only guarantees which rows are selected.
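
If a particular order matters, the three selected indices can be reordered afterwards, which is cheap because only three items are involved. A minimal sketch, assuming the same `data` array as above; `last`, `in_original_order` and `by_value` are names introduced here for illustration:

```python
import numpy as np

data = np.array([[[512, 520, 1, 130523]],
                 [[520, 614, 573, 7448]],
                 [[614, 616, 615, 210]],
                 [[616, 622, 619, 269]],
                 [[622, 624, 623, 162]],
                 [[625, 770, 706, 8822]],
                 [[770, 776, 773, 241]]])

last = data[..., -1].flatten()
idx = np.argpartition(-last, 3)[:3]  # indices of the 3 largest, arbitrary order

# Same order as in `data`: sort the indices themselves.
in_original_order = data[np.sort(idx)]

# Largest value first: order the three indices by their negated values.
by_value = data[idx[np.argsort(-last[idx])]]

print(in_original_order)
print(by_value)
```

Both results keep the `(3, 1, 4)` shape, so they match the layout of the input.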


For comparison, here is how you could find the 3 highest rows by using np.argsort (which performs a full sort):

idx = np.argsort(data[..., -1].flatten())
print(data[idx[-3:]])

yields

[[[   520    614    573   7448]]

 [[   625    770    706   8822]]

 [[   512    520      1 130523]]]

Note: np.argsort is faster for small arrays:

In [63]: %timeit idx = np.argsort(data[..., -1].flatten())
100000 loops, best of 3: 2.6 µs per loop

In [64]: %timeit idx = np.argpartition(-data[...,-1].flatten(), 3)
100000 loops, best of 3: 5.61 µs per loop

But np.argpartition is faster for large arrays:

In [92]: data2 = np.tile(data, (10**3,1,1))
In [93]: data2.shape
Out[93]: (7000, 1, 4)

In [94]: %timeit idx = np.argsort(data2[..., -1].flatten())
10000 loops, best of 3: 164 µs per loop

In [95]: %timeit idx = np.argpartition(-data2[...,-1].flatten(), 3)
10000 loops, best of 3: 49.5 µs per loop
unutbu
  • This is great. I did find though that in some cases I wouldn't get things in the right order; is there anything in this method that doesn't guarantee perfect ordering? – cjm2671 Aug 22 '14 at 13:31
  • This should work for any data whose dtype is a floating or int dtype. When using `np.argpartition(-data, 3)`, (note the minus sign) `NaNs` will be sorted at the end, so the highest values will be finite (if any exist). Note also that if `data`'s dtype is `object`, then the NaNs get sorted to the beginning, so then the highest values might contain NaNs. I'm not sure if this is answering your question. An example where things are not sorted in the right order would be very helpful. – unutbu Aug 22 '14 at 14:09
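
The NaN behaviour described in the comment above can be checked with a small float array. This is a sketch on made-up values (`x` is not from the question), assuming a float dtype rather than `object`:

```python
import numpy as np

x = np.array([1.0, np.nan, 5.0, 3.0, np.nan, 4.0, 2.0])

# Negating keeps NaN as NaN, and NaNs are ordered last,
# so the first three indices point to finite values.
idx = np.argpartition(-x, 3)[:3]

top3 = np.sort(x[idx])
print(top3)
```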

I simplified the structure of your list-of-lists in order to focus on the main issue. You can use sorted() with a custom key function:

my_list =  [[512, 520, 1, 130523], 
        [520, 614 , 573, 7448],
        [614, 616, 615, 210],
        [616, 622, 619, 269], 
        [622, 624, 623, 162], 
        [625, 770, 706, 8822], 
        [770, 776, 773, 241]]

def sort_by(a):
    return a[3]

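# sorted() returns a new list; sort in descending order and keep the first three rows
top3 = sorted(my_list, key=sort_by, reverse=True)[:3]
print(top3) # prints [[512, 520, 1, 130523], [625, 770, 706, 8822], [520, 614, 573, 7448]]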
Nir Alfasi