
I want to get the indices of the largest n values in a multi-dimensional numpy array. To get the indices of the largest n values in a one-dimensional numpy array, I found this. After testing in the interactive Python shell, it seems that bottleneck.argpartsort does not work on a multi-dimensional numpy array. To get the index of the largest value in a multi-dimensional numpy array, I found this, but it cannot return the largest n. The approach I came up with is to convert the multi-dimensional numpy array into a list of {value: index} entries (with each index represented as a tuple), sort the list by value, and read off the indices. Is there an easier or faster way?

stamaimer
  • Reshape to one dimension, then search, then get the original indexes through arithmetic calculations involving the dimensions before the reshape? – Jérôme Apr 22 '15 at 14:23
  • Maybe `flatten()` the original array, then use your 1D solution, finally calculate the real nD indices using the original shape? – chw21 Apr 22 '15 at 14:25
  • What is the problem with the second link? Can you show us a part of this multi-dimensional array? I don't see why it shouldn't work ... – plonser Apr 22 '15 at 14:44
  • http://stackoverflow.com/questions/26603747/get-the-indices-of-n-highest-values-in-an-ndarray – Lee Apr 22 '15 at 14:47
  • @plonser the second link only returns the index of the maximum value of the multi-dimensional numpy array – stamaimer Apr 23 '15 at 00:12
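
The "arithmetic calculations" mentioned in the comments can be sketched for the 2-D case as follows. This is only an illustration: the array `a` and the value of `flat_index` are made-up examples, and the arithmetic assumes the default row-major (C) order used by `flatten()`.

```python
import numpy as np

a = np.arange(12).reshape(3, 4)   # hypothetical example array
flat_index = 7                    # a position in the flattened (row-major) array

# Manual arithmetic for a 2-D, C-ordered array:
# row = k // ncols, col = k % ncols
row, col = divmod(flat_index, a.shape[1])

# np.unravel_index does the same conversion for any number of dimensions
assert (row, col) == np.unravel_index(flat_index, a.shape)
print(row, col, a[row, col])
```

For more than two dimensions the manual arithmetic gets tedious, which is why `np.unravel_index` is usually the more convenient choice.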

1 Answer


I don't have access to bottleneck, so this example uses argsort instead, but you should be able to use bottleneck in the same way:

#!/usr/bin/env python
import numpy as np
N = 4
a = np.random.random(20).reshape(4, 5)
print(a)

# Convert it into a 1D array
a_1d = a.flatten()

# Find the indices in the 1D array
idx_1d = a_1d.argsort()[-N:]

# Convert idx_1d back into one index array per dimension
x_idx, y_idx = np.unravel_index(idx_1d, a.shape)

# Check that we got the largest values (printed in ascending order).
for x, y in zip(x_idx, y_idx):
    print(a[x, y])
chw21
  • Use [`argpartition`](http://docs.scipy.org/doc/numpy/reference/generated/numpy.argpartition.html) instead of `argsort` and you'll earn my upvote. ;-) – Jaime Apr 22 '15 at 15:49
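
Following the suggestion in the comment above, here is a minimal sketch of the same flatten-and-unravel approach with `np.argpartition` in place of a full `argsort`. The helper name `largest_n_indices` and the seeded example array are assumptions made for illustration and are not part of the original answer; the sketch also assumes `1 <= n <= a.size`.

```python
import numpy as np

def largest_n_indices(a, n):
    """Return one index array per dimension for the n largest values of a,
    ordered from largest to smallest. Assumes 1 <= n <= a.size."""
    flat = a.ravel()
    # argpartition moves the n largest elements into the last n slots,
    # in no particular order, without sorting the whole array
    candidates = np.argpartition(flat, -n)[-n:]
    # sort only those n candidates, largest first
    candidates = candidates[np.argsort(flat[candidates])[::-1]]
    # map the flat positions back to N-dimensional indices
    return np.unravel_index(candidates, a.shape)

rng = np.random.default_rng(0)   # seeded only so the example is repeatable
a = rng.random((4, 5))
idx = largest_n_indices(a, 4)
print(a[idx])                    # the four largest values, descending
```

Only the n selected candidates are fully sorted here, so for large arrays and small n this should do less work than sorting every element. The returned tuple can be used directly for indexing, as in `a[idx]`.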