
I want to extend the following question with a particular concern:

How to obtain the argmax of a[...] in terms of the original indices of a

>>> import numpy as np
>>> a = (np.random.random((10, 10))*10).astype(int)
>>> a
array([[4, 1, 7, 4, 3, 3, 8, 9, 3, 0],
       [7, 7, 8, 9, 9, 6, 1, 4, 2, 0],
       [6, 9, 4, 9, 2, 7, 9, 0, 8, 6],
       [2, 4, 7, 8, 0, 6, 0, 7, 1, 8],
       [7, 9, 7, 0, 1, 2, 3, 7, 9, 6],
       [7, 1, 1, 0, 5, 1, 8, 8, 5, 5],
       [5, 4, 3, 0, 0, 4, 4, 5, 5, 4],
       [9, 5, 0, 5, 8, 1, 6, 4, 8, 5],
       [5, 8, 0, 8, 2, 6, 4, 9, 5, 1],
       [2, 5, 0, 1, 4, 0, 0, 9, 6, 4]])
>>> np.unravel_index(a.argmax(), a.shape)
(0, 7)
>>> np.unravel_index(a[a>5].argmax(), a.shape)
(0, 2)
>>> np.unravel_index(a[a>5].argmax(), a[a>5].shape)
(2,)
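The last two attempts fail because `a[a>5]` is a 1-D copy of the selected values in row-major order, so its `argmax` is a position in that compressed array rather than in `a`. A minimal sketch of mapping that position back to `a`, using `np.flatnonzero` (a technique not used in the original post) on a small made-up 3×3 array:

```python
import numpy as np

# Small made-up example array (not the random one from the question).
a = np.array([[4, 1, 7],
              [9, 2, 6],
              [3, 9, 0]])
mask = a > 5

# a[mask] is a 1-D copy of the selected values in row-major order,
# so its argmax is a position in that compressed array, not in a.
pos = a[mask].argmax()

# np.flatnonzero(mask) lists the flat indices of a where mask is True,
# in the same row-major order, so it maps compressed positions back to a.
flat_idx = np.flatnonzero(mask)[pos]
idx = tuple(int(i) for i in np.unravel_index(flat_idx, a.shape))
print(idx)  # (1, 0)
```

This costs one more pass over the mask and allocates an index array as long as the number of selected elements, which may matter at the data sizes discussed in the comments.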
majkrzak

2 Answers


With a boolean mask, how about:

np.where((a > 5) & (a == a[a > 5].max()))

or

mask = a > 5
np.where(mask & (a == a[mask].max()))
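One behavioural difference from `argmax` worth checking: `np.where` returns every position that attains the masked maximum, not just the first one. A small sketch on a made-up array with a tie:

```python
import numpy as np

# Made-up array where the masked maximum (9) appears twice.
a = np.array([[4, 9, 7],
              [9, 2, 6]])
mask = a > 5

# np.where returns every position whose value equals the masked maximum,
# so ties produce several coordinates (argmax would return only the first).
rows, cols = np.where(mask & (a == a[mask].max()))
print(list(zip(rows.tolist(), cols.tolist())))  # [(0, 1), (1, 0)]

# np.where scans in row-major (C) order, so taking the first hit
# reproduces argmax's "first occurrence" rule.
first = (int(rows[0]), int(cols[0]))
print(first)  # (0, 1)
```

Also note that `a[mask].max()` raises `ValueError` when no element satisfies the condition, so an empty mask needs to be handled separately.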
Learning is a mess
  • This is the current approach I'm using, but it most likely doubles the runtime. Or am I wrong, and is there some fancy optimization behind it? – majkrzak May 30 '23 at 13:02
  • There are 4 passes through the input array, indeed. How big is your typical input? How fast do you need to be? – Learning is a mess May 30 '23 at 13:09
  • About 120 GB before loading, so the constant factor in the complexity matters. I guess it may be fastest and most memory-efficient to do the argmax manually here. – majkrzak May 30 '23 at 13:16
  • I tried an `argmax, valmax = reduce(lambda acc, itx: itx if acc[1] – majkrzak May 30 '23 at 16:32
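Following the idea in the comments of doing the argmax manually, here is one possible sketch. It is a blocked, vectorized scan rather than the `reduce`-based approach mentioned above, and `masked_argmax_chunked`, `threshold` and `rows_per_chunk` are hypothetical names. It assumes a 2-D numeric array and a simple `> threshold` condition, and processes a block of rows at a time so the temporary arrays are bounded by the block size rather than by the full array:

```python
import numpy as np


def masked_argmax_chunked(a, threshold, rows_per_chunk=1024):
    """Return (row, col) of the largest a[i, j] with a[i, j] > threshold.

    Hypothetical helper: scans a 2-D numeric array one block of rows at a
    time. Ties resolve to the first occurrence in row-major order, like
    np.argmax. Returns None if no element exceeds the threshold.
    """
    # Sentinel for hidden entries: the smallest value of the dtype
    # (-inf for floats). Valid entries are strictly greater than the
    # threshold, so they beat the sentinel whenever anything is hidden.
    if np.issubdtype(a.dtype, np.integer):
        fill = np.iinfo(a.dtype).min
    else:
        fill = -np.inf

    best_val = None
    best_idx = None
    for start in range(0, a.shape[0], rows_per_chunk):
        block = a[start:start + rows_per_chunk]
        valid = block > threshold
        if not valid.any():
            continue  # nothing selected in this block
        # Hide invalid entries behind the sentinel, then take the block argmax.
        r, c = np.unravel_index(np.where(valid, block, fill).argmax(), block.shape)
        val = block[r, c]
        # Strict ">" keeps the earlier block on ties (first-occurrence rule).
        if best_val is None or val > best_val:
            best_val = val
            best_idx = (start + int(r), int(c))
    return best_idx


a = np.array([[4, 1, 7],
              [9, 2, 6],
              [3, 9, 0]])
print(masked_argmax_chunked(a, 5, rows_per_chunk=2))  # (1, 0)
```

If the data lives on disk, a memory-mapped array (for example from `np.load(path, mmap_mode='r')`) can be passed in so that roughly one block needs to be resident at a time. Whether this is faster than the `np.where` one-liner depends on the data and hardware; it has not been benchmarked here.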

You could consider the masked-array API (`numpy.ma`):

import numpy as np

arr = np.random.randint(10, 100, size=(10, 10))
mask = arr > 50

# Note: entries that are True in `mask` are considered "invalid"
# or "masked", and are thus disregarded. This is the opposite of
# boolean mask indexing, where only the True entries are
# retrieved.

masked = np.ma.array(arr, mask=mask)
out = np.unravel_index(masked.argmax(), masked.shape)

Results:

>>> arr
array([[58, 75, 78, 46, 89, 54, 35, 18, 13, 99],
       [30, 11, 24, 10, 15, 41, 40, 15, 94, 28],
       [84, 84, 83, 72, 39, 22, 57, 51, 91, 23],
       [54, 99, 72, 63, 30, 14, 91, 46, 98, 74],
       [27, 90, 93, 25, 41, 82, 39, 42, 57, 64],
       [98, 63, 79, 13, 91, 12, 36, 71, 95, 30],
       [23, 34, 51, 19, 37, 31, 58, 65, 20, 31],
       [26, 73, 67, 21, 67, 89, 72, 80, 11, 48],
       [87, 64, 38, 74, 60, 31, 30, 54, 71, 44],
       [78, 94, 62, 38, 79, 23, 61, 62, 18, 25]])
>>> print(masked)
[[-- -- -- 46 -- -- 35 18 13 --]
 [30 11 24 10 15 41 40 15 -- 28]
 [-- -- -- -- 39 22 -- -- -- 23]
 [-- -- -- -- 30 14 -- 46 -- --]
 [27 -- -- 25 41 -- 39 42 -- --]
 [-- -- -- 13 -- 12 36 -- -- 30]
 [23 34 -- 19 37 31 -- -- 20 31]
 [26 -- -- 21 -- -- -- -- 11 48]
 [-- -- 38 -- -- 31 30 -- -- 44]
 [-- -- -- 38 -- 23 -- -- 18 25]]
>>> out
(7, 9)
>>> arr[out]
48
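This example hides the values above 50, which is the inverse of the question's `a > 5` selection. A minimal sketch of applying the masked API to the question's condition by inverting the mask, on a small made-up array:

```python
import numpy as np

# Small made-up example array.
a = np.array([[4, 1, 7],
              [9, 2, 6],
              [3, 9, 0]])

# numpy.ma treats True as "hide this entry", so to keep only a > 5
# (the condition from the question) the mask must be inverted.
masked = np.ma.array(a, mask=~(a > 5))
idx = tuple(int(i) for i in np.unravel_index(masked.argmax(), masked.shape))
print(idx)  # (1, 0)
```

One caveat: masked entries are filled with the dtype's smallest value before the argmax, so if every entry is masked, `argmax` still returns an index (0) instead of raising. Checking `masked.count()` (the number of unmasked entries) first avoids a silently wrong answer.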
Chrysophylaxs