
I've been struggling with this problem in various guises for a long time, and never managed to find a good solution.

Basically, if I want to write a function that performs an operation along a given but arbitrary axis of an array of arbitrary rank, in the style of (for example) np.mean(A, axis=some_axis), I have no idea in general how to do this.

The issue always seems to come down to the inflexibility of the slicing syntax: if I want to access the ith slice along the 3rd axis, I can use A[:,:,i], but I can't generalise this to the nth axis.

JoshD
    Expanding on @hpaulj's first suggestion: perhaps unexpectedly, the result of swapaxes -> some operation that creates a new array -> swapaxes on that new array is often contiguous; see for example [here](https://stackoverflow.com/a/47861800/7207392) – Paul Panzer Oct 25 '18 at 16:58
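A small illustration of that comment, assuming the "operation" is an elementwise ufunc (here a multiplication). NumPy ufuncs default to order='K', which lays out a newly allocated output to follow the input's memory order as closely as possible, so swapping the axes back gives a C-contiguous array. Other operations may allocate their output differently, so treat this as one example rather than a guarantee.

```python
import numpy as np

A = np.arange(24).reshape(2, 3, 4)   # C-contiguous input
B = np.swapaxes(A, 0, 1)             # strided view, neither C- nor F-contiguous
C = B * 2                            # ufunc output; order='K' mirrors B's memory layout
D = np.swapaxes(C, 0, 1)             # swap back on the new array

print(B.flags['C_CONTIGUOUS'])       # False
print(D.flags['C_CONTIGUOUS'])       # True
```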

2 Answers


numpy functions use several approaches to do this:

  • transpose the axes to move the target axis to a known position, usually first or last, and if needed transpose the result back

  • reshape (along with transpose) to reduce the problem to simpler dimensions. If your focus is on the n'th dimension, it might not matter whether the (:n) dimensions are flattened or not. They are just 'going along for the ride'.

  • construct an indexing tuple. idx = (slice(None), slice(None), j); A[idx] is the equivalent of A[:,:,j]. Start with a list or array of the right size, fill it with slices, set the entry for the target axis, and then convert it to a tuple (tuples are immutable, so they can't be filled in place).

  • Construct indices with index_tricks tools like np.r_, np.s_, etc.
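A minimal sketch of the transpose, reshape and indexing-tuple approaches listed above. The helper names (take_slice, take_slice_s, apply_last_axis, mean_via_reshape) are hypothetical; np.moveaxis is used as a convenience wrapper for the transpose step, and the reshape version is only an illustration, since np.mean already accepts axis.

```python
import numpy as np

def take_slice(A, i, axis):
    """i-th slice of A along `axis`: A[:, :, i] when axis == 2."""
    idx = [slice(None)] * A.ndim   # one full slice per dimension
    idx[axis] = i                  # put the integer index on the target axis
    return A[tuple(idx)]           # indexing needs a tuple, not a list

def take_slice_s(A, i, axis):
    """Same idea with np.s_; trailing dimensions are implicitly full slices."""
    axis = axis % A.ndim           # allow negative axis values
    return A[(np.s_[:],) * axis + (i,)]

def apply_last_axis(func, A, axis):
    """Move `axis` to the end, apply `func` there, then move it back.

    Assumes `func` works on the last axis and keeps its length
    (e.g. a cumulative sum with axis=-1).
    """
    moved = np.moveaxis(A, axis, -1)
    return np.moveaxis(func(moved), -1, axis)

def mean_via_reshape(A, axis):
    """Reduce along `axis` by flattening every other axis into one."""
    moved = np.moveaxis(A, axis, -1)            # target axis last
    other_shape = moved.shape[:-1]              # dims 'going along for the ride'
    flat = moved.reshape(-1, moved.shape[-1])   # 2-D: (everything else, target)
    return flat.mean(axis=1).reshape(other_shape)
```

For example, take_slice(A, i, 2) gives the same result as A[:, :, i], and apply_last_axis(lambda x: np.cumsum(x, axis=-1), A, 1) matches np.cumsum(A, axis=1).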

Study code that accepts an axis parameter. The compiled ufuncs won't help here, but functions like tensordot, take_along_axis, apply_along_axis and np.cross are written in Python and use one or more of these tricks.
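When one of these axis-aware helpers already fits the task, it can be called directly instead of building slices by hand. A short sketch on an assumed 3-D example array; the 1-D function passed to apply_along_axis is only an illustration.

```python
import numpy as np

A = np.arange(24).reshape(2, 3, 4)
axis = 2

# np.take with an integer index drops that axis, like A[:, :, 1]
s = np.take(A, 1, axis=axis)

# apply_along_axis calls a 1-D function on every 1-D slice along `axis`
r = np.apply_along_axis(lambda v: v.max() - v.min(), axis, A)

# take_along_axis pairs with arg* functions that accept an axis
idx = np.argsort(-A, axis=axis)              # indices for descending order
desc = np.take_along_axis(A, idx, axis=axis)
```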

hpaulj

You cannot generalize this. In fact, numpy.mean(a, axis=axis_index) is a good example to look at here. Even numpy, which is written mostly in C, loops through the axis indices to know where to compute the mean. Have a look at reduction.c, which is at the core of numpy.mean. Even though the data is arranged in an advantageous layout before the operation is performed, looping through all the axes with your axis_index is always required.
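A pure-Python sketch of that idea: compute a mean along an arbitrary axis by explicitly walking the positions on that axis. The function name is hypothetical, and this only mirrors the concept; it is not how reduction.c is actually structured, and it is much slower than np.mean.

```python
import numpy as np

def mean_by_looping(A, axis):
    """Mean of A along `axis`, computed by looping over that axis."""
    axis = axis % A.ndim                     # accept negative axes like numpy
    n = A.shape[axis]
    idx = [slice(None)] * A.ndim             # full slices on every other axis
    total = np.zeros(A.shape[:axis] + A.shape[axis + 1:], dtype=float)
    for k in range(n):                       # the explicit loop over the axis index
        idx[axis] = k
        total += A[tuple(idx)]
    return total / n
```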

b-fg