
I have an N-dimensional array, where N is a variable, from which I want to take elements along a given set of axes.

The objective is similar to that of a related question, except that the solution there seems to work only when the number of dimensions and the axes are fixed.

For example, suppose that from a 3D array we want to extract the elements along axis 0 for every multi-index over the other two axes. If the value of N is known beforehand, this can be hard-coded:

import numpy as np
a = np.arange(12).reshape((2,3,2))
ydim = a.shape[1]
zdim = a.shape[2]
for y in range(ydim):
    for z in range(zdim):
        print(a[:,y,z])

which gives the output

[0 6]
[1 7]
[2 8]
[3 9]
[ 4 10]
[ 5 11]

Q: How can this be achieved when N and the axes are not known beforehand?
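One loop-free way to generalize the nested loops is sketched below, assuming the goal is simply to enumerate the sub-arrays in C order over the chosen axes. It moves the chosen axes to the front with `np.moveaxis`, then merges them into a single leading axis with `reshape`, so iterating over the result visits every multi-index over those axes. The name `rows_along_axes` is a hypothetical helper, not a NumPy function, and the `reshape` may copy because the moved array is generally not contiguous.

```python
import numpy as np

def rows_along_axes(a, axes):
    """Hypothetical helper (not part of NumPy).

    Return an array whose k-th entry is the sub-array of `a` at the k-th
    multi-index over `axes` (C order), keeping all other axes intact.
    """
    axes = [ax % a.ndim for ax in axes]      # accept negative axes
    front = list(range(len(axes)))
    # Put the iterated axes first, in the order given.
    moved = np.moveaxis(a, axes, front)
    # Merge them into one leading axis; the remaining axes keep their shape.
    return moved.reshape((-1,) + moved.shape[len(axes):])

a = np.arange(12).reshape((2, 3, 2))
for row in rows_along_axes(a, [1, 2]):
    print(row)
```

For the 3D example this prints the same six rows as the nested loops, but N and the list of axes are no longer baked into the code.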

For a single axis, numpy.take or numpy.take_along_axis does the job. I am looking for a similar function for multiple axes: a function, say, take_along_axes(), which could be used as follows:

ax = [1, 2]  ## list of axes from which indices are taken
## every multi-index over those axes
for idx in np.ndindex(*(a.shape[i] for i in ax)):
    ## take_along_axes is the desired function; it does not exist in NumPy
    print(take_along_axes(a, idx, axes=ax))

The expected output is the same as the previous one.
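A minimal sketch of such a function follows, assuming `index` holds one integer per entry of `axes`. It builds an indexing tuple that is `:` on every axis and then pins each chosen axis to its index, so it works for any N. `take_along_axes` is the hypothetical name from the question, not a NumPy API; the loop uses `np.ndindex` to enumerate every multi-index over the chosen axes.

```python
import numpy as np

def take_along_axes(a, index, axes):
    """Hypothetical helper (not part of NumPy): fix `a` at `index` along
    `axes` and return the sub-array spanned by the remaining axes."""
    if len(index) != len(axes):
        raise ValueError("need exactly one index per axis")
    key = [slice(None)] * a.ndim      # start with ':' on every axis
    for ax, i in zip(axes, index):
        key[ax] = i                   # pin each chosen axis to its index
    return a[tuple(key)]              # basic indexing returns a view

a = np.arange(12).reshape((2, 3, 2))
ax = [1, 2]  # list of axes from which indices are taken
# np.ndindex yields every multi-index over the chosen axes, in C order.
for idx in np.ndindex(*(a.shape[i] for i in ax)):
    print(take_along_axes(a, idx, axes=ax))
```

Because only integers and slices are used, each result is a view into `a` rather than a copy.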

user137846