
Given a batch of two 9x9 images with 2 channels each, built like this:

import numpy as np

img1 = np.arange(162).reshape(9,9,2).copy()
img2 = img1 * 2
batch = np.array([img1, img2])

I need to split each image into 3x3x2 regions (stride 3), then locate and replace the max element of each channel in each region. For the example above, these elements are at:

  • (:, 2, 2, :)
  • (:, 2, 5, :)
  • (:, 2, 8, :)
  • (:, 5, 2, :)
  • (:, 5, 5, :)
  • (:, 5, 8, :)
  • (:, 8, 2, :)
  • (:, 8, 5, :)
  • (:, 8, 8, :)
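To double-check the list of positions above, here is a small sketch that recomputes them for the example batch. It splits each image into 3x3 blocks, merges the two within-block axes through a transpose (this makes a copy, which is harmless because the result is only read), and converts each flat argmax back to absolute row/column coordinates with np.unravel_index. The names bh, bw, blocks and windows are illustrative only.

```python
import numpy as np

img1 = np.arange(162).reshape(9, 9, 2)
batch = np.array([img1, img1 * 2])            # shape (2, 9, 9, 2)
region_size = 3
m, h, w, ch = batch.shape
bh, bw = h // region_size, w // region_size   # regions along height and width

# Split rows and columns into (region index, offset inside region): (2, 3, 3, 3, 3, 2)
blocks = batch.reshape(m, bh, region_size, bw, region_size, ch)

# Put the two in-region axes next to each other and merge them. The transpose
# makes this reshape copy, which is harmless here because we only read.
windows = blocks.transpose(0, 1, 3, 2, 4, 5).reshape(m, bh, bw, region_size ** 2, ch)

flat = windows.argmax(axis=3)                 # (2, 3, 3, 2): flat index of each max
row_in, col_in = np.unravel_index(flat, (region_size, region_size))

# Offsets inside a region -> absolute image coordinates
rows = np.arange(bh)[None, :, None, None] * region_size + row_in
cols = np.arange(bw)[None, None, :, None] * region_size + col_in

positions = sorted(set(zip(rows.ravel().tolist(), cols.ravel().tolist())))
print(positions)
# [(2, 2), (2, 5), (2, 8), (5, 2), (5, 5), (5, 8), (8, 2), (8, 5), (8, 8)]
```

Because the example values grow along both rows and columns, every region's max sits at its bottom-right corner, which is why all the positions end in 2, 5 or 8.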

So far my solution is this:

batch_size, _, _, channels = batch.shape
region_size = 3

# For the (0, 0) region
region_slice = (slice(batch_size), slice(region_size), slice(region_size), slice(channels))
region = batch[region_slice]
new_values = np.arange(batch_size * channels)

# Flatten each channel of an image
region_3d = region.reshape(batch_size, region_size ** 2, channels)

# Find indices of max element for each channel
region_3d_argmax = region_3d.argmax(axis=1)

# Manually unravel indices
region_argmax = (
    np.repeat(np.arange(batch_size), channels),
    *np.unravel_index(region_3d_argmax.ravel(), (region_size, region_size)),
    np.tile(np.arange(channels), batch_size)
)

batch[region_slice][region_argmax] = new_values
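The code above handles only the (0, 0) region. As a baseline, here is a sketch of the same steps generalized to every region with a plain Python loop (the loop is an illustration, not part of the original code). The image and channel index arrays do not depend on the region, so they are built once; the argmax is read from a reshaped (possibly copied) region, while the write goes through region, which is a basic-slicing view of batch.

```python
import numpy as np

img1 = np.arange(162).reshape(9, 9, 2)
batch = np.array([img1, img1 * 2])              # shape (2, 9, 9, 2)

batch_size, height, width, channels = batch.shape
region_size = 3
new_values = np.arange(batch_size * channels)   # one value per (image, channel)

# Image and channel index for every (image, channel) pair, in the same
# row-major order as argmax(...).ravel() below; these don't depend on the region.
batch_idx = np.repeat(np.arange(batch_size), channels)
chan_idx = np.tile(np.arange(channels), batch_size)

for top in range(0, height, region_size):
    for left in range(0, width, region_size):
        region = batch[:, top:top + region_size, left:left + region_size, :]  # a view
        # reshape copies here, which is fine: it is only read to find the argmax
        region_3d = region.reshape(batch_size, region_size ** 2, channels)
        flat_argmax = region_3d.argmax(axis=1)               # (batch_size, channels)
        rows, cols = np.unravel_index(flat_argmax.ravel(), (region_size, region_size))
        # Writing through the view updates batch in place
        region[batch_idx, rows, cols, chan_idx] = new_values
```

Both problems listed next are still present in this loop: the reshape copy and the hand-built index tuple.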

There are two problems with this code:

  • Reshaping region may return a copy instead of view
  • Manual unraveling
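The copy-versus-view concern can be checked directly with np.shares_memory. This sketch (using the question's example batch) shows three cases: reshaping the non-contiguous (0, 0) region copies, splitting axes of the contiguous batch gives a view, and merging the two non-adjacent in-region axes (which needs a transpose first) copies again.

```python
import numpy as np

img1 = np.arange(162).reshape(9, 9, 2)
batch = np.array([img1, img1 * 2])  # a fresh C-contiguous array, shape (2, 9, 9, 2)

# 1) The (0, 0) region is a strided view, not contiguous; merging its two
#    spatial axes cannot be expressed with strides, so reshape copies.
region = batch[:, :3, :3, :]
region_3d = region.reshape(2, 9, 2)
print(np.shares_memory(region_3d, batch))   # False -> copy

# 2) Splitting axes of the contiguous batch gives a view.
batch6d = batch.reshape(2, 3, 3, 3, 3, 2)
print(np.shares_memory(batch6d, batch))     # True -> view
batch6d[0, 0, 2, 0, 2, 0] = -1              # element (image 0, row 2, col 2, channel 0)
print(batch[0, 2, 2, 0])                    # -1

# 3) Merging the two in-region axes (2 and 4) needs a transpose first,
#    because they are not adjacent; reshaping the transposed array copies.
merged = batch6d.transpose(0, 1, 3, 2, 4, 5).reshape(2, 3, 3, 9, 2)
print(np.shares_memory(merged, batch))      # False -> copy
```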

What is a better way to perform this operation?

Kentzo

1 Answer


With merging axes

A better way (more memory-efficient, and hence faster) is to use advanced indexing with an appropriate indexing tuple and assign the replacement values through it:

m,n = idx.shape
indexer = np.arange(m)[:,None],idx,np.arange(n)
batch_3d[indexer] = replacement_array.reshape(m,n)

The replacement array has to match the indexed shape (m, n), so reshape it first if it is 1-D (skip the reshape if it already has that shape). Note that batch_3d[indexer].flat = ... would not work: advanced indexing returns a copy, so .flat would modify only that temporary.
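A runnable sketch of the indexer approach. It assumes batch_3d is a C-contiguous (m, k, n) array (images, merged region pixels, channels) filled with random integers purely for illustration, and replacement_array is a hypothetical 1-D array of m*n values. It also shows why the write has to be an assignment to batch_3d[indexer] rather than to .flat of the indexed result.

```python
import numpy as np

rng = np.random.default_rng(0)
m, k, n = 2, 9, 2                              # assumed: images, merged pixels, channels
batch_3d = rng.integers(0, 100, size=(m, k, n))
replacement_array = np.arange(m * n)           # hypothetical 1-D replacement values

idx = batch_3d.argmax(axis=1)                  # (m, n): index of the max along axis 1
indexer = np.arange(m)[:, None], idx, np.arange(n)   # all three broadcast to (m, n)

# Reading batch_3d[indexer] gives a new (m, n) array, so writing to its .flat
# changes only that temporary and leaves batch_3d untouched.
before = batch_3d.copy()
batch_3d[indexer].flat = -1
print(np.array_equal(batch_3d, before))        # True

# Assigning through the indexer writes into batch_3d itself.
batch_3d[indexer] = replacement_array.reshape(m, n)
print(np.array_equal(batch_3d[indexer], replacement_array.reshape(m, n)))  # True
```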

Alternatively, use the built-in np.put_along_axis, with p as the replacement array:

np.put_along_axis(batch_3d,idx[:,None,:],p.reshape(m,1,n),axis=1)

Note: the idx used in this post is the one generated by idx = batch_3d.argmax(axis=1), where batch_3d has each region's pixels merged into a single axis (like region_3d in the question); that is why the manual unravel-indices step can be skipped.
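A sketch checking that np.put_along_axis performs the same replacement as an explicit loop. As before, batch_3d is an assumed random (m, k, n) array and p a hypothetical vector of m*n replacement values; idx is produced exactly as in the note.

```python
import numpy as np

rng = np.random.default_rng(1)
m, k, n = 2, 9, 2                          # assumed: images, merged pixels, channels
batch_3d = rng.integers(0, 100, size=(m, k, n))
p = np.arange(m * n)                       # hypothetical replacement values

idx = batch_3d.argmax(axis=1)              # (m, n), as in the note

# Reference result built with a plain loop, for comparison
expected = batch_3d.copy()
for i in range(m):
    for j in range(n):
        expected[i, idx[i, j], j] = p.reshape(m, n)[i, j]

# put_along_axis needs indices/values with the same ndim as batch_3d, so a
# length-1 axis is inserted at position 1 (the axis being indexed).
np.put_along_axis(batch_3d, idx[:, None, :], p.reshape(m, 1, n), axis=1)
print(np.array_equal(batch_3d, expected))  # True
```

The inserted length-1 axis is the whole trick: put_along_axis pairs each entry of the index array with the same position in every other axis, just as the broadcasting indexer tuple does.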


Without merging axes

To do the argmax-based replacements along multiple axes, we can define helper functions that avoid merging axes that are not adjacent, since merging those would force a copy.

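def indexer_skip_one_axis(a, axis):
    # Index tuple that re-inserts a length-1 axis at position `axis`
    # (undoes the axis removed by a reduction such as argmax)
    return tuple(slice(None) if i!=axis else None for i in range(a.ndim))

def argmax_along_axes(a, axis):
    # a is input array
    # axis is tuple of two axes along which argmax indices are to be computed
    # Position of the max along axis[0], for every slot along axis[1]
    argmax1 = (a.argmax(axis[0]))[indexer_skip_one_axis(a,axis[0])]
    val_argmax1 = np.take_along_axis(a,argmax1,axis=axis[0])
    # Slot along axis[1] that holds the overall max
    argmax2 = (val_argmax1.argmax(axis[1]))[indexer_skip_one_axis(a,axis[1])]
    # axis[0] position belonging to that winning axis[1] slot
    val_argmax2 = np.take_along_axis(argmax1,argmax2,axis=axis[1])
    # Open mesh over all axes, then swap in the two argmax arrays
    r = list(np.ix_(*[np.arange(i) for i in a.shape]))
    r[axis[0]] = val_argmax2
    r[axis[1]] = argmax2
    return tuple(r)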
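On NumPy 1.22 or later, np.argmax accepts keepdims=True, which does the job of indexer_skip_one_axis. Below is a sketch of the same two-step argmax written with that flag; argmax_along_axes_keepdims is a hypothetical name, not a NumPy function. It is checked against a.max over the same two axes on random data.

```python
import numpy as np

def argmax_along_axes_keepdims(a, axes):
    """Hypothetical keepdims variant of argmax_along_axes (needs NumPy >= 1.22).

    Returns an index tuple selecting, for every position along the other axes,
    the maximum over the two axes in `axes`.
    """
    ax0, ax1 = axes
    # Position along ax0 of the max, for every slot of ax1 (ax0 kept as size 1)
    arg0 = np.argmax(a, axis=ax0, keepdims=True)
    best0 = np.take_along_axis(a, arg0, axis=ax0)
    # Which slot of ax1 holds the overall max (both axes now size 1)
    arg1 = np.argmax(best0, axis=ax1, keepdims=True)
    # Look up the ax0 position that belongs to the winning ax1 slot
    arg0 = np.take_along_axis(arg0, arg1, axis=ax1)
    # Open mesh for every other axis, then drop in the two argmax arrays
    index = list(np.ix_(*[np.arange(size) for size in a.shape]))
    index[ax0] = arg0
    index[ax1] = arg1
    return tuple(index)

rng = np.random.default_rng(2)
a = rng.integers(0, 1000, size=(2, 3, 3, 3, 3, 2))
idx = argmax_along_axes_keepdims(a, (2, 4))
print(np.array_equal(a[idx], a.max(axis=(2, 4), keepdims=True)))  # True
```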

Hence, all the replacements for our case can be done with:

m,n,r,s = batch.shape
# batch is C-contiguous, so this reshape only splits axes and returns a view
batch6D = batch.reshape(m,n//3,3,r//3,3,s)
batch6D[argmax_along_axes(batch6D, axis=(2,4))] = new_values.reshape(m,1,1,1,1,s)
out = batch6D.reshape(m,n,r,s)
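An end-to-end sketch on the question's example batch. It repeats the two helpers so it runs on its own, then checks that every region's max was replaced and that the writes landed in batch itself, since batch6D is a view of the contiguous batch (the concern raised in the comments below).

```python
import numpy as np

def indexer_skip_one_axis(a, axis):
    # Re-insert a length-1 axis at `axis` after a reduction
    return tuple(slice(None) if i != axis else None for i in range(a.ndim))

def argmax_along_axes(a, axis):
    argmax1 = a.argmax(axis[0])[indexer_skip_one_axis(a, axis[0])]
    val_argmax1 = np.take_along_axis(a, argmax1, axis=axis[0])
    argmax2 = val_argmax1.argmax(axis[1])[indexer_skip_one_axis(a, axis[1])]
    val_argmax2 = np.take_along_axis(argmax1, argmax2, axis=axis[1])
    r = list(np.ix_(*[np.arange(i) for i in a.shape]))
    r[axis[0]] = val_argmax2
    r[axis[1]] = argmax2
    return tuple(r)

img1 = np.arange(162).reshape(9, 9, 2)
batch = np.array([img1, img1 * 2])
m, n, r, s = batch.shape
new_values = np.arange(m * s)        # one value per (image, channel), as in the question

batch6D = batch.reshape(m, n // 3, 3, r // 3, 3, s)   # a view of batch
batch6D[argmax_along_axes(batch6D, axis=(2, 4))] = new_values.reshape(m, 1, 1, 1, 1, s)

# batch itself was modified through the view
print(batch[:, 2, 2, :].tolist())    # [[0, 1], [2, 3]]
```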
Divakar
  • I have updated the Q with the shape of the new_values array. Please elaborate on your code samples. The intention of some lines is not clear to me... – Kentzo Apr 12 '19 at 23:00
  • @Kentzo Just use the last two suggested solutions. – Divakar Apr 12 '19 at 23:04
  • @Kentzo Did that make sense? – Divakar Apr 12 '19 at 23:26
  • I need to meditate over the meaning of put_along_axis's arguments :) – Kentzo Apr 12 '19 at 23:29
  • Is there a way to avoid `batch_3d`? It looks like that depending on the shape `reshape` may return a copy and modification of `batch_3d` won't be reflected in `batch`. – Kentzo Apr 13 '19 at 18:21
  • @Kentzo Nope, `reshape` won't return a copy. How did you reach that conclusion? – Divakar Apr 13 '19 at 18:38
  • @Kentzo .. given `batch` is a contiguous array - https://stackoverflow.com/questions/26998223/what-is-the-difference-between-contiguous-and-non-contiguous-arrays, https://docs.scipy.org/doc/numpy/reference/generated/numpy.ascontiguousarray.html – Divakar Apr 13 '19 at 18:49
  • I realized that my example did not demonstrate the complete case. Please see the updated question. – Kentzo Apr 13 '19 at 19:15
  • Thanks. Also updated the Q with changes I made in the meantime. – Kentzo Apr 15 '19 at 21:34