
I have a 2D NumPy array of size ~70k * 10k. In every row, I want to replace with zero all values that are smaller than the N-th largest element of that row. For example:

arr = np.array([[1, 0, 6, 5, 2, 5], 
                [7, 5, 2, 6, 7, 3], 
                [3, 5, 1, 5, 6, 4]])

For N = 3 the result should be:

result = np.array([[0, 0, 6, 5, 0, 5], # 3 largest in row: 6, 5, 5
                   [7, 0, 0, 6, 7, 0], 
                   [0, 5, 0, 5, 6, 0]])

The positions of numbers that were not replaced and the shape of the array should stay the same.
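To pin down the intended semantics before optimizing, a slow per-row reference can be handy to test faster versions against. This is a sketch: the function name `keep_n_largest_loop` is hypothetical, it assumes `N` is at most the number of columns, and ties with the N-th largest value are kept.

```python
import numpy as np

def keep_n_largest_loop(arr, n):
    """Hypothetical reference version: return a copy of arr in which every
    entry smaller than the n-th largest value of its row is set to 0.

    Slow (Python-level loop over rows), but easy to reason about.
    """
    result = arr.copy()
    for row in result:
        # n-th largest value in this row: sort descending, take index n - 1
        threshold = np.sort(row)[::-1][n - 1]
        # row is a view into result, so this assignment edits result
        row[row < threshold] = 0
    return result

arr = np.array([[1, 0, 6, 5, 2, 5],
                [7, 5, 2, 6, 7, 3],
                [3, 5, 1, 5, 6, 4]])
print(keep_n_largest_loop(arr, 3))
```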

MSeifert
  • Did you try anything yourself? Got any ideas as to how one would go about doing it? – Ma0 Aug 22 '17 at 11:20
  • There are similar questions like https://stackoverflow.com/questions/30332908/n-largest-values-in-each-row-of-ndarray, but np.partition changes the array shape, and that's what I must avoid. – Anastasia Manokhina Aug 22 '17 at 11:31
  • The 3rd Google result I got was [this](https://stackoverflow.com/questions/19666626/replace-all-elements-of-python-numpy-array-that-are-greater-than-some-value). It should be enough to get you started. – Ma0 Aug 22 '17 at 11:33
  • Your 3rd Google result is about one particular value (255). I need to determine the N largest values for every row independently, because they can differ from row to row. I would like to find the most elegant solution. – Anastasia Manokhina Aug 22 '17 at 11:40
  • If it were exactly what you wanted, I would mark your question as a duplicate. I said _to get you started_. – Ma0 Aug 22 '17 at 11:45
  • I wouldn't have asked this question if I hadn't _started_ yet. – Anastasia Manokhina Aug 22 '17 at 11:51
  • What should the result be for that same array but N=2? That is, how should duplicates be handled? – Stefan Pochmann Aug 22 '17 at 11:55
  • All duplicates should remain if at least one of them is in the top 20 in sorted order. – Anastasia Manokhina Aug 23 '17 at 15:49
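Reading that rule as "ties with the row's N-th largest value are kept", a threshold comparison already satisfies it: anything equal to the threshold survives, so a row can end up with more than N nonzero entries. A minimal check for N = 2 on the example array, using np.partition to get each row's threshold and np.where to build a new array instead of modifying `arr`:

```python
import numpy as np

arr = np.array([[1, 0, 6, 5, 2, 5],
                [7, 5, 2, 6, 7, 3],
                [3, 5, 1, 5, 6, 4]])
N = 2

# N-th largest value of each row (duplicates count as separate elements)
threshold = np.partition(arr, -N, axis=1)[:, -N]   # -> [5, 7, 5]
# keep everything >= threshold, so duplicates of the threshold survive
result = np.where(arr < threshold[:, None], 0, arr)
print(result)
# [[0 0 6 5 0 5]
#  [7 0 0 0 7 0]
#  [0 5 0 5 6 0]]
```

Rows 0 and 2 keep three values for N = 2 because both 5s tie for second place.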

1 Answer


You could find the N-th largest value of each row using np.partition and then use boolean indexing to replace everything that is below that value in its row:

import numpy as np
arr = np.array([[1, 0, 6, 5, 2, 5], 
                [7, 5, 2, 6, 7, 3], 
                [3, 5, 1, 5, 6, 4]])

N = 3
nlargest = np.partition(arr, -N, axis=1)[:, -N]
arr[arr < nlargest[:, None]] = 0
arr
# array([[0, 0, 6, 5, 0, 5],
#        [7, 0, 0, 6, 7, 0],
#        [0, 5, 0, 5, 6, 0]])
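The question mentions an array of about 70k * 10k. np.partition returns a partitioned copy of its input, so on the full array that copy is as large as the original (roughly 5.6 GB if the dtype is int64 or float64), plus a boolean mask of the same shape. Assuming in-place modification is acceptable, one way to bound the extra memory is to apply the same threshold step to blocks of rows. The function name `zero_below_nth_largest` and the `chunk_rows` default are arbitrary choices for this sketch, not tuned values.

```python
import numpy as np

def zero_below_nth_largest(arr, n, chunk_rows=1024):
    """Hypothetical helper: zero, in place, every entry smaller than the
    n-th largest value of its row, processing chunk_rows rows at a time so
    the partition copy and the boolean mask only cover one block."""
    for start in range(0, arr.shape[0], chunk_rows):
        block = arr[start:start + chunk_rows]            # basic slice: a view
        threshold = np.partition(block, -n, axis=1)[:, -n]
        block[block < threshold[:, None]] = 0            # writes through to arr
    return arr

rng = np.random.default_rng(0)
big = rng.integers(0, 100, size=(5000, 200))

# Same result computed on the whole array at once, for comparison
expected = big.copy()
expected[expected < np.partition(expected, -3, axis=1)[:, -3][:, None]] = 0

zero_below_nth_largest(big, 3, chunk_rows=512)
print(np.array_equal(big, expected))  # True
```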
MSeifert