I am trying to find the indices of the maximum values of a 3D array over axes (0, 1), i.e. one (i, j, k) index per column k.
I tried combining np.argwhere with np.amax, but when a column contains its maximum value more than once, it returns several indices for that column:
import numpy as np

x = np.array([[[4, 5, 9],
               [0, 13, 6]],
              [[12, 11, 13],
               [5, 13, 8]]])
np.amax(x, axis=(0, 1))
>> [12 13 13]  # I want the indices of these maximum values, one per column
x == np.amax(x, axis=(0, 1))
# This seems to be the problem: the second column (index 1) contains two True values, because 13 appears twice there
>> [[[False False False]
  [False  True False]]
 [[ True False  True]
  [False  True False]]]
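A quick way to confirm that ties are the cause is to count the True values per column of that mask. This is a small sketch reusing the same array; the variable names `mask` and `ties` are only illustrative.

```python
import numpy as np

x = np.array([[[4, 5, 9], [0, 13, 6]],
              [[12, 11, 13], [5, 13, 8]]])

# Boolean mask: True wherever an element equals its column's maximum over axes (0, 1).
# The shape-(3,) result of np.amax broadcasts against the trailing axis of x.
mask = x == np.amax(x, axis=(0, 1))

# Number of positions that attain the maximum in each column.
ties = mask.sum(axis=(0, 1))
print(ties)  # [1 2 1]
```

Any column with a count above 1 contributes more than one row to the `np.argwhere` result.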
np.argwhere(x == np.amax(x, axis=(0, 1)))  # I want 3 indices (one per column), but this returns 4
>> [[0 1 1]
 [1 0 0]
 [1 0 2]
 [1 1 1]]
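If you already have the `np.argwhere` result, one possible workaround is to keep only the first row for each value of the last coordinate. This sketch relies on `np.unique(..., return_index=True)` returning the position of the first occurrence of each value, and on `np.argwhere` listing indices in row-major (C) order, so "first" means the smallest (i, j) in row-major order. The helper name `first_per_column` is hypothetical.

```python
import numpy as np

def first_per_column(x):
    """Hypothetical helper: one (i, j, k) max index per column k of a 3D array."""
    # All positions equal to their column's maximum over axes (0, 1); may contain ties.
    hits = np.argwhere(x == np.amax(x, axis=(0, 1)))
    # For each distinct column index k (last coordinate), keep the first hit.
    _, first = np.unique(hits[:, 2], return_index=True)
    return hits[first]

x = np.array([[[4, 5, 9], [0, 13, 6]],
              [[12, 11, 13], [5, 13, 8]]])
print(first_per_column(x))
# [[1 0 0]
#  [0 1 1]
#  [1 0 2]]
```

This still builds the full boolean mask, so it mainly helps when the `argwhere` output is already needed for something else.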
So is there an idiomatic NumPy way to get exactly one max index per column of a 3D array over axes (0, 1)?
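One vectorized approach (a sketch, not the only option) is to merge axes 0 and 1 with `reshape`, call `np.argmax` along the merged axis, and convert the flat positions back to (i, j) with `np.unravel_index`. `argmax` returns the first occurrence of the maximum, so each column gets exactly one index. The helper name `max_indices_01` is hypothetical. The reshape is valid because axes 0 and 1 are the leading axes, so the merged axis enumerates (i, j) in row-major order, which matches the default `order='C'` of `unravel_index`.

```python
import numpy as np

def max_indices_01(x):
    """Hypothetical helper: (i, j, k) of the first maximum over axes (0, 1), one per k."""
    n0, n1, n2 = x.shape
    # Merge the two leading axes: row r of `flat` is x[r // n1, r % n1, :].
    flat = x.reshape(n0 * n1, n2)
    # argmax along the merged axis; ties resolve to the first occurrence.
    r = np.argmax(flat, axis=0)
    # Convert flat row positions back to (i, j) coordinates.
    i, j = np.unravel_index(r, (n0, n1))
    return np.column_stack((i, j, np.arange(n2)))

x = np.array([[[4, 5, 9], [0, 13, 6]],
              [[12, 11, 13], [5, 13, 8]]])
idx = max_indices_01(x)
print(idx)
# [[1 0 0]
#  [0 1 1]
#  [1 0 2]]
print(x[idx[:, 0], idx[:, 1], idx[:, 2]])  # [12 13 13]
```

Because `argmax` breaks ties by taking the first occurrence, the second column resolves to (0, 1, 1) rather than (1, 1, 1). If you want arrays usable directly for fancy indexing, return `(i, j, np.arange(n2))` instead of stacking them.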