Suppose that I have an array with m + n dimensions.
How do I compute the argmax over the last n dimensions using NumPy?
The output array should, given the first m indices, return a list of n indices that correspond to the maximal value of array[m indices].
For example:
Input:
m = 1
n = 2
array = [[[3,1],[2,2]],[[1,2],[2,4]]]
Output:
[[0,0], [1,1]]
Here [0,0] is the position of 3, the maximum of [[3,1],[2,2]],
and [1,1] is the position of 4, the maximum of [[1,2],[2,4]].
Please note that m and n are stored in variables and change from case to case.
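
To make the desired behaviour concrete, here is a minimal sketch of one possible approach: collapse the last n axes into a single axis, take the argmax there, and convert the flat positions back with `np.unravel_index`. The function name `argmax_last_n` is hypothetical, and the sketch assumes n >= 1 and that ties should resolve to the first occurrence (which is what `argmax` does).

```python
import numpy as np

def argmax_last_n(arr, n):
    """Indices of the maximum over the last n axes (hypothetical helper).

    Assumes n >= 1. The result has shape arr.shape[:-n] + (n,).
    """
    arr = np.asarray(arr)
    lead_shape = arr.shape[:arr.ndim - n]   # the first m axes
    tail_shape = arr.shape[arr.ndim - n:]   # the last n axes
    # Collapse the last n axes into one so a single argmax covers them all.
    flat = arr.reshape(lead_shape + (-1,))
    flat_idx = flat.argmax(axis=-1)
    # Convert each flat position back into an n-dimensional index.
    idx = np.unravel_index(flat_idx, tail_shape)
    # Stack the n index arrays along a new last axis.
    return np.stack(idx, axis=-1)

array = np.array([[[3, 1], [2, 2]], [[1, 2], [2, 4]]])
result = argmax_last_n(array, n=2)
print(result.tolist())  # [[0, 0], [1, 1]]
```

Because m is derived from `arr.ndim - n`, only n needs to be passed in, so the same call works when m and n change from case to case.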