
I would like to sort a numpy array and find out where each element went.

numpy.argsort tells me, for each index in the sorted array, which index of the unsorted array goes there. I'm looking for the inverse: for each index in the unsorted array, where does that element go in the sorted array?

a = np.array([1, 4, 2, 3])

# a sorted is [1,2,3,4]
# the 1 goes to index 0
# the 4 goes to index 3
# the 2 goes to index 1
# the 3 goes to index 2

# desired output
[0, 3, 1, 2]

# for comparison, argsort output
[0, 2, 3, 1]
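Put differently, the desired output is the inverse of the permutation that argsort returns: if argsort says sorted slot k is filled from index i, the desired array holds k at position i. A plain loop makes that definition explicit. This is a minimal sketch; the helper name `positions_loop` is hypothetical, and `kind="stable"` is an assumption about how ties should be ordered:

```python
import numpy as np

def positions_loop(a):
    """Hypothetical helper: sorted position of each element of `a`."""
    # order[k] = original index of the k-th smallest element
    order = np.argsort(a, kind="stable")
    positions = np.empty(len(order), dtype=np.intp)
    for k, i in enumerate(order):
        positions[i] = k  # element a[i] ends up at sorted index k
    return positions

a = np.array([1, 4, 2, 3])
print(np.argsort(a, kind="stable"))  # [0 2 3 1]
print(positions_loop(a))             # [0 3 1 2]
```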

A simple solution uses numpy.searchsorted:

np.searchsorted(np.sort(a), a)
# produces [0, 3, 1, 2]
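One caveat worth checking: with duplicates, searchsorted (which uses side='left' by default) maps equal values to the same index, so the result is no longer a permutation of 0..n-1. A quick check on the array with a repeated 2 from the question:

```python
import numpy as np

a = np.array([1, 4, 2, 3, 5, 2])
pos = np.searchsorted(np.sort(a), a)  # side='left' by default
print(pos)  # [0 4 1 3 5 1]: both 2s map to index 1

# A permutation of 0..n-1 would contain every index exactly once.
print(np.array_equal(np.sort(pos), np.arange(len(a))))  # False
```

Whether shared positions are acceptable depends on the use case; the argsort-based approaches in the answers give every element a distinct position instead.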

I'm unhappy with this solution because it seems inefficient: it sorts and then searches, in two separate steps.

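This fancy indexing, argsort(a)[argsort(a)], is not a reliable alternative: it matches searchsorted for the first array below but not for the second, which contains a duplicate: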

a = np.array([1, 4, 2, 3, 5])
print(np.argsort(a)[np.argsort(a)])
print(np.searchsorted(np.sort(a),a))


a = np.array([1, 4, 2, 3, 5, 2])
print(np.argsort(a)[np.argsort(a)])
print(np.searchsorted(np.sort(a),a))
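The expression argsort(a)[argsort(a)] applies the sorting permutation to itself rather than inverting it. It only agrees with the inverse when applying the permutation twice equals undoing it once, which is the case for the first array (its argsort is a single 3-cycle plus fixed points) but not for the second. A sketch comparing p[p] with the true inverse, using `kind="stable"` so the tie order is deterministic and relying on the fact that the argsort of a permutation is its inverse:

```python
import numpy as np

for a in (np.array([1, 4, 2, 3, 5]), np.array([1, 4, 2, 3, 5, 2])):
    p = np.argsort(a, kind="stable")
    inv = np.argsort(p)  # argsort of a permutation is its inverse
    print(p[p], inv, np.array_equal(p[p], inv))
# [0 3 1 2 4] [0 3 1 2 4] True
# [0 5 4 3 2 1] [0 4 1 3 5 2] False
```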
lhk

2 Answers


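You can just use argsort twice on the array. At first it seems confusing that this works, but it makes sense once you notice that argsort(a) is a permutation of 0..n-1 whose k-th entry is the original index of the k-th smallest element. Taking argsort of that permutation returns, at position i, the index k where the value i sits, i.e. the k with argsort(a)[k] == i, which is exactly where a[i] lands in the sorted array.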

import numpy as np

a = np.array([1, 4, 2, 3])
argSorted = np.argsort(a)  # [0, 2, 3, 1]
invArgSorted = np.argsort(argSorted)  # [0, 3, 1, 2]
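A small sketch to sanity-check the double argsort: if r is the result, then np.sort(a)[r] should reproduce a, because r[i] is the sorted position of a[i]. The function name `double_argsort_positions` is hypothetical, and passing `kind="stable"` to the first argsort is an assumption about the desired tie-breaking (equal values keep their original order). The second argsort runs on a permutation, which has no ties, so its sort kind does not matter.

```python
import numpy as np

def double_argsort_positions(a):
    # First argsort: order[k] = index in `a` of the k-th smallest element.
    # kind="stable" keeps equal elements in their original order.
    order = np.argsort(a, kind="stable")
    # Second argsort: order is a permutation of 0..n-1, so argsort(order)
    # is its inverse, i.e. the sorted position of each original element.
    return np.argsort(order)

for a in (np.array([1, 4, 2, 3]), np.array([1, 4, 2, 3, 5, 2])):
    r = double_argsort_positions(a)
    # Every element really lands at the position r claims.
    assert np.array_equal(np.sort(a)[r], a)
    print(r)
# [0 3 1 2]
# [0 4 1 3 5 2]
```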
markuscosinus
  • Did you forget to pass something to the second call of `argsort`? You could improve the answer by explaining how it works and why it makes sense :) – MB-F Jan 31 '19 at 11:55
  • I and others have suggested this double argsort without elaborate explanation. https://stackoverflow.com/q/54388972/901925 – hpaulj Jan 31 '19 at 12:05

You just need to invert the permutation that sorts the array. As shown in the linked question, you can do that like this:

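import numpy as np

def sorted_position(array):
    # a[k] is the index (in `array`) of the k-th smallest element
    a = np.argsort(array)
    # Invert the permutation: the element at index a[k] goes to sorted position k.
    # A copy is used as the index so it is not the same buffer being written to.
    a[a.copy()] = np.arange(len(a))
    return a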

print(sorted_position([0.1, 0.2, 0.0, 0.5, 0.8, 0.4, 0.7, 0.3, 0.9, 0.6]))
# [1 2 0 5 8 4 7 3 9 6]
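A variant of the same scatter that writes into a fresh array instead of copying the index array (a sketch; `inverse_permutation` is a hypothetical name). The scatter itself is a single O(n) pass, whereas a second argsort sorts again in O(n log n); the random check below only confirms that both approaches return the same permutation, it does not measure speed.

```python
import numpy as np

def inverse_permutation(p):
    # inv[p[k]] = k: the element sorted into slot k came from index p[k].
    inv = np.empty_like(p)
    inv[p] = np.arange(len(p))
    return inv

rng = np.random.default_rng(0)
x = rng.random(1000)
order = np.argsort(x)
# Scatter-based inverse and double argsort agree.
assert np.array_equal(inverse_permutation(order), np.argsort(order))

print(inverse_permutation(np.argsort([0.1, 0.2, 0.0, 0.5, 0.8, 0.4, 0.7, 0.3, 0.9, 0.6])))
# [1 2 0 5 8 4 7 3 9 6]
```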
jdehesa