
Suppose that I have an array with m + n dimensions.

How do I compute the argmax with NumPy over the last n dimensions?

So for each combination of the first m indices, the output array should contain the n indices that locate the maximal value of the sub-array selected by those m indices.

For example:

Input:

m = 1
n = 2
array = [[[3,1],[2,2]],[[1,2],[2,4]]]

Output:

[[0,0], [1,1]] 

These correspond to 3 as the maximum of [[3,1],[2,2]] and 4 as the maximum of [[1,2],[2,4]].
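For this m = 1 example, the expected output can be reproduced one block at a time. Below is a minimal sketch, assuming NumPy is available. np.argmax without an axis searches the flattened block, and np.unravel_index converts that flat position back into a (row, column) pair.

```python
import numpy as np

array = np.array([[[3, 1], [2, 2]], [[1, 2], [2, 4]]])

# Iterate over the single leading axis (m = 1). For each 2x2 block, argmax()
# returns the position of the maximum in the flattened block, and
# np.unravel_index maps it back to (row, col) within the block.
result = [
    [int(i) for i in np.unravel_index(block.argmax(), block.shape)]
    for block in array
]
print(result)  # [[0, 0], [1, 1]]
```

This Python loop only covers m = 1 and is meant to show what each output pair means. It is not an efficient general solution.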

Note that m and n are stored in variables and vary from case to case.

Raghul Raj
Light

1 Answer


NumPy's argmax takes an axis argument. In your case the M×N part is always two-dimensional, so this should do the trick:

```python
import numpy as np

m = 1
n = 2
array = [[[3,1],[2,2]],[[1,2],[2,4]],[[1,2],[7,4]]]

np.argmax(array, axis=2)
# array([[0, 0],
#        [1, 1],
#        [1, 0]])
```
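One caveat: axis=2 takes the argmax of each innermost row separately, not of each whole 2x2 block. On both arrays shown here, the two results happen to agree. They can disagree when one row's largest entry sits in a different column from the block's overall maximum. The values below are made up purely for illustration.

```python
import numpy as np

# One leading index (m = 1) holding a single 2x2 block; its overall maximum, 9,
# sits at row 0, column 1. Values chosen only to illustrate the difference.
array = np.array([[[1, 9], [5, 2]]])

# axis=2 takes the argmax of each innermost row on its own:
# row [1, 9] -> 1 and row [5, 2] -> 0.
per_row = np.argmax(array, axis=2)
print(per_row.tolist())  # [[1, 0]]

# Position of the block's overall maximum, for comparison.
block = array[0]
block_max = [int(i) for i in np.unravel_index(block.argmax(), block.shape)]
print(block_max)  # [0, 1]
```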
Raghul Raj
  • I think you did not understand the question. I need a function argmax_over_axes(array, m, n), not a solution to a specific example. My program should also handle cases such as M = 10 and N = 15, where the last N dimensions are not just two. – Light Jun 07 '20 at 17:27
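Below is a minimal sketch of the argmax_over_axes(array, m, n) function requested in the comment. It assumes that array has exactly m + n dimensions and that no axis has length zero. The last n axes are collapsed into one, a single argmax is taken along that axis, and np.unravel_index maps the flat positions back to n-dimensional indices. This is one common NumPy approach, not the only possible one.

```python
import numpy as np


def argmax_over_axes(array, m, n):
    """For every index over the first m axes, return the n indices of the
    maximum over the last n axes. Output shape: array.shape[:m] + (n,)."""
    a = np.asarray(array)
    if a.ndim != m + n:
        raise ValueError(f"expected {m + n} dimensions, got {a.ndim}")
    lead, tail = a.shape[:m], a.shape[m:]
    # Collapse the last n axes into one, so a single argmax gives the flat
    # position of each block's maximum (first occurrence on ties).
    flat = a.reshape(lead + (-1,)).argmax(axis=-1)
    # Convert the flat positions back into n index arrays (one per trailing
    # axis), then stack them along a new final axis.
    return np.stack(np.unravel_index(flat, tail), axis=-1)


array = [[[3, 1], [2, 2]], [[1, 2], [2, 4]]]
print(argmax_over_axes(array, m=1, n=2).tolist())  # [[0, 0], [1, 1]]
```

Because np.argmax returns the first occurrence of the maximum, ties resolve to the earliest position in row-major order. Using reshape avoids Python-level loops, but it may copy the data when the input is not contiguous.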