
Consider a function that operates on an n-d array. I want to modify it so that it also accepts an (n+1)-d array, lets me specify an axis, and applies the operation to every n-d slice along that axis. For example, I want to take functions like

def some_operation(arr):
    #does something to the n-d array
    return result

and let higher dimensional arrays be worked on like

def f(arr, axis=0):
    # the array is (n+1)-d, so we operate on a given axis
    n = range(arr.shape[axis])

    # "..." keeps every remaining axis whole, whatever n is
    if axis == 0:
        result = np.array([some_operation(arr[j, ...]) for j in n])
    elif axis == 1:
        result = np.array([some_operation(arr[:, j, ...]) for j in n])
    # and so on

    return result

but with a helper function or decorator to streamline the code like

@along_axis
def some_operation(arr):
    # does the same thing to an n-d array,
    # but now also works on (n+1)-d arrays
    return result

I think what I'm looking for is something like np.apply_along_axis() as documented here, except that it should handle functions of n-d arrays rather than only 1-d slices. I also think this post on how to dynamically slice a specified axis can be used to handle some of the indexing issues seen above. Above all, I'd strongly prefer that the operations avoid copying the array. I'm also unsure whether the list comprehension I'm using above is the best way to do this.
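On the copy concern, a quick check may help. np.moveaxis returns a view, and indexing with an integer plus slices is also a view; np.shares_memory can confirm both. In a wrapper built this way, only the final np.array (or np.stack) call allocates memory, and that memory holds the per-slice results, not a copy of the input. The sketch below uses arbitrary random shapes chosen only for illustration, and it also shows why np.apply_along_axis does not fit directly: it only ever passes 1-d slices to the function.

```python
import numpy as np

arr = np.random.rand(5, 4, 3)

# np.moveaxis permutes strides and returns a view; no data is copied
moved = np.moveaxis(arr, 2, 0)
print(moved.shape)                   # (3, 5, 4)
print(np.shares_memory(moved, arr))  # True

# an integer index plus implicit full slices is basic indexing: also a view
sl = moved[1]                        # same elements as arr[:, :, 1]
print(np.shares_memory(sl, arr))     # True

# np.apply_along_axis hands the function 1-d slices only
seen = []

def record_shape(v):
    seen.append(v.shape)
    return v.sum()

totals = np.apply_along_axis(record_shape, 0, arr)
print(seen[0])       # (5,)
print(totals.shape)  # (4, 3)
```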

Here's a working example of what I'm shooting for. First, here's a wrapper I made:

import numpy as np
from functools import wraps

def along_axis(func):
    """
    Wraps a function of a 2d array. Enables performing that function's operations
    along all 2d slices down an axis of a 3d array. When given a 2d array, the
    axis parameter is ignored.
    """
    @wraps(func)
    def wrapper(arr, *args, axis=0, **kwargs):
        if arr.ndim == 3:
            # move the chosen axis to the end; np.moveaxis returns a view, not a copy
            arr = np.moveaxis(arr, axis, -1)
            # apply func to each 2d slice and stack the results into a new array
            return np.array([func(arr[:, :, t], *args, **kwargs) for t in range(arr.shape[-1])])
        # 2d input: call the original function directly
        return func(arr, *args, **kwargs)
    return wrapper
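One possible generalization of the wrapper above, offered as a sketch rather than an established API: a hypothetical decorator factory `along_axis_nd(core_ndim)` that works for any base dimensionality instead of hard-coding the 2d/3d cases. It moves `axis` to the front with np.moveaxis (a view) and iterates over the leading axis, which yields the same slices, in the same order, as `arr[:, :, t]` after moving the axis to the end. The example function `row_sums` is also hypothetical.

```python
import numpy as np
from functools import wraps

def along_axis_nd(core_ndim):
    """Hypothetical decorator factory: lift a function of core_ndim-d arrays
    so it also accepts (core_ndim + 1)-d arrays, mapping over `axis`."""
    def decorator(func):
        @wraps(func)
        def wrapper(arr, *args, axis=0, **kwargs):
            arr = np.asarray(arr)
            if arr.ndim == core_ndim:
                # base case: call the original function unchanged
                return func(arr, *args, **kwargs)
            if arr.ndim != core_ndim + 1:
                raise ValueError(
                    f"expected {core_ndim}-d or {core_ndim + 1}-d input, got {arr.ndim}-d"
                )
            # moving `axis` to the front is a view; iterating yields a view of each slice
            slices = np.moveaxis(arr, axis, 0)
            return np.stack([func(s, *args, **kwargs) for s in slices])
        return wrapper
    return decorator

@along_axis_nd(core_ndim=2)
def row_sums(image):
    # example 2d function: sum along each row
    return image.sum(axis=1)

images = np.random.rand(7, 4, 5)
print(row_sums(images).shape)          # (7, 4)
print(row_sums(images, axis=2).shape)  # (5, 7)
print(row_sums(images[0]).shape)       # (4,)
```

np.stack is used instead of np.array because it requires every per-slice result to have the same shape and raises a clear error otherwise.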

Now I'll apply it to a function that is meant to act on 2d arrays. This one in particular is inspired by this post on radial means of images.

@along_axis
def radial_mean(image):
    """
    Computes radial means of a 2d image. After wrapping it, we can input 3d arrays as well.
    """
    # create a radial mesh centered on the image
    X, Y = np.meshgrid(np.arange(image.shape[1]), np.arange(image.shape[0]))
    R = np.sqrt((X - image.shape[1]//2)**2 + (Y - image.shape[0]//2)**2)

    # integer radii 0, 1, ..., int(R.max()) - 1
    r = np.arange(int(R.max()))

    # for each radius, average the pixels whose distance lies in [r - 0.5, r + 0.5)
    f = np.vectorize(lambda r: image[(R >= r - 0.5) & (R < r + 0.5)].mean())
    return f(r)
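Because the wrapper calls radial_mean once per slice, its per-image cost matters, and np.vectorize is essentially a Python loop (the NumPy docs say it is provided for convenience, not performance). A grouped sum with np.bincount can likely replace it. This is a swapped-in technique, not the approach from the linked post; `radial_mean_bincount` is a hypothetical name, and the sketch assumes the same center (`shape // 2`) and the same half-open bins `[r - 0.5, r + 0.5)` as the code above.

```python
import numpy as np

def radial_mean_bincount(image):
    """Hypothetical alternative to radial_mean: same bins, grouped with np.bincount."""
    ny, nx = image.shape
    Y, X = np.indices((ny, nx))  # Y = row index, X = column index
    R = np.sqrt((X - nx // 2) ** 2 + (Y - ny // 2) ** 2)
    # r - 0.5 <= R < r + 0.5  is the same as  r == floor(R + 0.5)
    bins = np.floor(R + 0.5).astype(int).ravel()
    n_bins = int(R.max())  # same radii as np.arange(int(R.max()))
    sums = np.bincount(bins, weights=image.ravel(), minlength=n_bins)
    counts = np.bincount(bins, minlength=n_bins)
    # keep only the radii the original computes
    return sums[:n_bins] / counts[:n_bins]
```

It keeps the 2d signature, so it can be decorated with @along_axis in the same way as radial_mean.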

Now I can compute the radial means for many images or just a single image like before.

images = np.random.rand(1000,16,16)
print(radial_mean(images, axis=0).shape)
print(radial_mean(images[0], axis=0).shape)

> (1000, 11)
> (11,)

Does this approach make sense? Is there a better way? Thank you.

suneater

1 Answer

You can build an index tuple by filling the positions before and after the axis you work along with full slices (slice(None), the equivalent of `:`).

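axis = 0
n = range(arr.shape[axis])
# full slices (equivalent to ":") for every axis before `axis` ...
preidx = [slice(None, None) for i in range(axis)]
# ... and for every axis after it
postidx = [slice(None, None) for i in range(arr.ndim - axis - 1)]
# arr[(*preidx, j, *postidx)] is arr[:, ..., j, ..., :] with j at position `axis`
result = np.array([some_operation(arr[(*preidx, j, *postidx)]) for j in n])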
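The same indexing idea can be packaged as a reusable function. A minimal sketch, with `apply_over_axis` as a hypothetical helper name: NumPy fills in full slices for any trailing axes that are not indexed, so only the leading `slice(None)` entries are strictly needed, and each `arr[idx]` is a view rather than a copy.

```python
import numpy as np

def apply_over_axis(func, arr, axis=0):
    """Hypothetical helper: apply func to each slice taken at index j of `axis`,
    using an index tuple of full slices followed by j."""
    axis = axis % arr.ndim  # allow negative axes
    results = []
    for j in range(arr.shape[axis]):
        # full slices for the leading axes, the integer j at `axis`;
        # trailing axes are kept whole implicitly, so no post-index is needed
        idx = (slice(None),) * axis + (j,)
        results.append(func(arr[idx]))
    return np.stack(results)

arr = np.random.rand(3, 4, 5)
out = apply_over_axis(lambda a: a.sum(), arr, axis=1)
print(out.shape)  # (4,)
```

Iterating over `np.moveaxis(arr, axis, 0)` would produce the same slices; the explicit index tuple just avoids the transposed view.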
yann ziselman