10

I have an N-dimensional array (named A). For each index along the first axis of A, I want to obtain the coordinates of the maximum value over the other axes of A. The result should be a 2-dimensional array holding the coordinates of the maximum value, one row per index of the first axis of A.

I already solved my problem using a loop, but I was wondering whether there is a more efficient way of doing this. My current solution (for an example array A) is as follows:

import numpy as np

A=np.reshape(np.concatenate((np.arange(0,12),np.arange(0,-4,-1))),(4,2,2))
maxpos=np.empty(shape=(4,2))
for n in range(0, 4):
    maxpos[n,:]=np.unravel_index(np.argmax(A[n,:,:]), A[n,:,:].shape)

Here, we would have:

A: 
[[[ 0  1]
  [ 2  3]]

 [[ 4  5]
  [ 6  7]]

 [[ 8  9]
  [10 11]]

 [[ 0 -1]
  [-2 -3]]]

maxpos:
[[ 1.  1.]
 [ 1.  1.]
 [ 1.  1.]
 [ 0.  0.]]

If there are multiple maximizers, I don't mind which is chosen.

I have tried to use np.apply_over_axes, but I haven't managed to make it return the outcome I want.

Divakar
CarlosH

2 Answers

13

You could do something like this -

# Reshape the input array to 2D, keeping the first axis as rows.
# Then, get the indices of the max values along the columns.
max_idx = A.reshape(A.shape[0],-1).argmax(1)

# Get unravel indices corresponding to original shape of A
maxpos_vect = np.column_stack(np.unravel_index(max_idx, A[0,:,:].shape))

Sample run -

In [214]: # Input array
     ...: A = np.random.rand(5,4,3,7,8)

In [215]: # Setup output array and use original loopy code
     ...: maxpos=np.empty(shape=(5,4)) # 4 because ndims in A is 5
     ...: for n in range(0, 5):
     ...:     maxpos[n,:]=np.unravel_index(np.argmax(A[n,:,:,:,:]), A[n,:,:,:,:].shape)
     ...:     

In [216]: # Proposed approach
     ...: max_idx = A.reshape(A.shape[0],-1).argmax(1)
     ...: maxpos_vect = np.column_stack(np.unravel_index(max_idx, A[0,:,:].shape))
     ...: 

In [219]: # Verify results
     ...: np.array_equal(maxpos.astype(int),maxpos_vect)
Out[219]: True

Generalize to n-dim array

We can generalize this to an n-dim array and get the argmax over the last N axes combined, with something like this -

def argmax_lastNaxes(A, N):
    s = A.shape
    new_shp = s[:-N] + (np.prod(s[-N:]),)
    max_idx = A.reshape(new_shp).argmax(-1)
    return np.unravel_index(max_idx, s[-N:])

The result would be a tuple of index arrays, one per axis. If you need the final output as a single array, you can use np.stack or np.concatenate.
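As a minimal sketch of that last step, here is `argmax_lastNaxes` (repeated so the snippet runs on its own) applied to the example array from the question. Stacking with `axis=-1` is one possible choice; it puts the coordinates along the last axis, which matches the layout of `maxpos` in the question.

```python
import numpy as np

def argmax_lastNaxes(A, N):
    # Same function as above, repeated so this snippet is self-contained
    s = A.shape
    new_shp = s[:-N] + (np.prod(s[-N:]),)
    max_idx = A.reshape(new_shp).argmax(-1)
    return np.unravel_index(max_idx, s[-N:])

# Example array from the question, shape (4, 2, 2)
A = np.reshape(np.concatenate((np.arange(0, 12), np.arange(0, -4, -1))), (4, 2, 2))

# Tuple of N index arrays -> one integer array of shape (4, 2)
maxpos = np.stack(argmax_lastNaxes(A, 2), axis=-1)
print(maxpos)
# [[1 1]
#  [1 1]
#  [1 1]
#  [0 0]]
```

Unlike the loop version, the result here has an integer dtype, so no `astype(int)` is needed before using it for indexing.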

Divakar
  • This answer was linked in https://stackoverflow.com/questions/62105979/how-to-find-argmax-of-last-2-axes. `maxpos` is what your OP asks for, but how is it useful? – hpaulj Jun 01 '20 at 17:20
  • @hpaulj Not sure what you are getting at by "usefulness". Maxpos values are argmax values for last two axes combined for a 3D array input with indices referred back to the shape of last two axes. Does this answer your question? Or are you asking how this applies to n-dim array? – Divakar Jun 01 '20 at 18:08
2

You can use a list comprehension

result = [np.unravel_index(np.argmax(r), r.shape) for r in A]

It's IMO more readable, but it won't be much faster than an explicit loop.
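For reference, a small sketch of the comprehension on the example array from the question. Wrapping the list in `np.array` is an addition here, assuming a 2-dimensional array is wanted rather than a list of tuples.

```python
import numpy as np

# Example array from the question, shape (4, 2, 2)
A = np.reshape(np.concatenate((np.arange(0, 12), np.arange(0, -4, -1))), (4, 2, 2))

# One (row, col) tuple per slice along the first axis
result = [np.unravel_index(np.argmax(r), r.shape) for r in A]

# Optional: convert the list of tuples to a (4, 2) integer array
result_arr = np.array(result)
print(result_arr)
```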

The fact that the main outer loop is in Python should matter only if the first dimension is actually the very big one.

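If this is the case (i.e. you have ten million 2x2 matrices), then flipping the approach is faster: spell out the comparisons for the four positions and let NumPy vectorize over the first axis (the array is called data here)...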

# true if 0,0 is not smaller than others
m00 = ((data[:,0,0] >= data[:,1,0]) &
       (data[:,0,0] >= data[:,0,1]) &
       (data[:,0,0] >= data[:,1,1]))

# true if 0,1 is not smaller than others
m01 = ((data[:,0,1] >= data[:,1,0]) &
       (data[:,0,1] >= data[:,0,0]) &
       (data[:,0,1] >= data[:,1,1]))

# true if 1,0 is not smaller than others
m10 = ((data[:,1,0] >= data[:,0,0]) &
       (data[:,1,0] >= data[:,0,1]) &
       (data[:,1,0] >= data[:,1,1]))

# true if 1,1 is not smaller than others
m11 = ((data[:,1,1] >= data[:,1,0]) &
       (data[:,1,1] >= data[:,0,1]) &
       (data[:,1,1] >= data[:,0,0]))

# choose which is max on equality
m01 &= ~m00
m10 &= ~(m00|m01)
m11 &= ~(m00|m01|m10)

# compute result
result = np.zeros((len(data), 2), np.int32)
result[:,1] |= m01|m11
result[:,0] |= m10|m11

On my machine the code above is about 50 times faster (for one million 2x2 matrices).
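A quick way to sanity-check the mask approach is to compare it with a reshape-plus-argmax computation on random data that contains ties. This is only a sketch: the function name `argmax_2x2_masks`, the random test data, and the seed are assumptions made for illustration, and no timing is measured.

```python
import numpy as np

def argmax_2x2_masks(data):
    # Hypothetical wrapper around the mask code above; data has shape (n, 2, 2)
    a, b, c, d = data[:, 0, 0], data[:, 0, 1], data[:, 1, 0], data[:, 1, 1]
    m00 = (a >= b) & (a >= c) & (a >= d)
    m01 = (b >= a) & (b >= c) & (b >= d)
    m10 = (c >= a) & (c >= b) & (c >= d)
    m11 = (d >= a) & (d >= b) & (d >= c)
    # On equality, prefer the earlier position in row-major order
    m01 &= ~m00
    m10 &= ~(m00 | m01)
    m11 &= ~(m00 | m01 | m10)
    result = np.zeros((len(data), 2), np.int32)
    result[:, 1] |= m01 | m11
    result[:, 0] |= m10 | m11
    return result

# Small integer range so that ties occur often
rng = np.random.default_rng(0)
data = rng.integers(0, 3, size=(1000, 2, 2))

# Reference: flatten the trailing axes, take argmax, map back to 2D coordinates
flat_idx = data.reshape(len(data), -1).argmax(1)
reference = np.column_stack(np.unravel_index(flat_idx, data.shape[1:]))

print(np.array_equal(argmax_2x2_masks(data), reference))
```

Because the masks prefer the earlier position in row-major order, ties are resolved the same way as `argmax`, which returns the first occurrence of the maximum.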

6502