Consider a function that operates on an n-d array. I want to modify it to take an (n+1)-d array, allow me to specify an axis, and have it perform operations on all the axis's n-d slices. For example, I want to take functions like
def some_operation(arr):
    # does something to the n-d array
    return result
and let higher dimensional arrays be worked on like
def f(arr, axis=0):
    # the array is (n+1)-d, so we operate on a given axis
    n = range(arr.shape[axis])
    if axis == 0:
        result = np.array([some_operation(arr[j,:,:,...,:]) for j in n])
    elif axis == 1:
        result = np.array([some_operation(arr[:,j,:,...,:]) for j in n])
    # and so on
    return result
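For what it's worth, the if/elif chain can probably be collapsed by moving the requested axis to the front and iterating over it. This is only a minimal sketch: `some_operation` here is a stand-in I made up (it just sums the slice), and I'm relying on `np.moveaxis` returning a view rather than a copy.

```python
import numpy as np

def some_operation(arr):
    # Placeholder (assumption): any function of an n-d array; here, the sum of all entries.
    return arr.sum()

def f(arr, axis=0):
    # Move the chosen axis to the front; np.moveaxis returns a view, not a copy.
    moved = np.moveaxis(arr, axis, 0)
    # Iterating over an array yields its slices along the first axis.
    return np.array([some_operation(sub) for sub in moved])

arr = np.arange(24).reshape(2, 3, 4)
print(f(arr, axis=0).shape)  # (2,)
print(f(arr, axis=1).shape)  # (3,)
```

This handles any axis without writing out each indexing pattern by hand.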
but with a helper function or decorator to streamline the code like
@along_axis
def some_operation(arr):
    # does the same thing to an n-d array
    # but also now works on (n+1)-d arrays
    return result
I think what I'm looking for is like np.apply_along_axis() as documented here, but with the ability to handle functions of n-d arrays rather than only 1-d slices. I also think this post on how to dynamically slice a specified axis can be used to handle some of the indexing issues seen above. One thing is for sure: I'd strongly prefer that the operations avoid copying the array. I'm also unsure whether the list comprehension I'm using above is the best way to do things.
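To make the dynamic-slicing idea concrete, here is a rough sketch. The helper name `slice_along` is hypothetical (my own, not from NumPy); it builds the index tuple for an arbitrary axis, and `np.shares_memory` is used to check that the resulting slice is a view and not a copy.

```python
import numpy as np

def slice_along(arr, j, axis=0):
    # Hypothetical helper: build an index tuple like (:, j, :) for the requested axis.
    index = [slice(None)] * arr.ndim
    index[axis] = j
    # Basic (slice/integer) indexing returns a view onto the original data.
    return arr[tuple(index)]

arr = np.zeros((4, 5, 6))
sub = slice_along(arr, 2, axis=1)
print(sub.shape)                   # (4, 6)
print(np.shares_memory(arr, sub))  # True
```

Note that the slices themselves are views, but collecting the per-slice results with np.array still allocates a new array for the output.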
Here's a working example of what I'm shooting for. First, here's a wrapper I made
import numpy as np
from functools import wraps

def along_axis(func):
    """
    Wraps a function of a 2d array. Enables performing that function's operations along all 2d slices down an axis of a 3d array. When given a 2d array, the axis parameter is ignored.
    """
    @wraps(func)
    def wrapper(arr, *args, axis=0, **kwargs):
        if arr.ndim == 3:
            arr = np.moveaxis(arr, axis, -1)
            return np.array([func(arr[:,:,t], *args, **kwargs) for t in range(arr.shape[-1])])
        return func(arr, *args, **kwargs)
    return wrapper
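The wrapper above is hard-coded to the 2d/3d case. One way it might generalize to the n-d/(n+1)-d case I described is a decorator factory that is told which dimensionality the function natively expects. This is only a sketch under that assumption; `along_axis_nd` and `column_sums` are names I invented for illustration.

```python
import numpy as np
from functools import wraps

def along_axis_nd(ndim):
    # Hypothetical decorator factory: ndim is the dimensionality func natively expects.
    def decorator(func):
        @wraps(func)
        def wrapper(arr, *args, axis=0, **kwargs):
            arr = np.asarray(arr)
            if arr.ndim == ndim + 1:
                # Bring the looped axis to the front (a view), then apply func per slice.
                moved = np.moveaxis(arr, axis, 0)
                return np.array([func(sub, *args, **kwargs) for sub in moved])
            if arr.ndim == ndim:
                # Native dimensionality: the axis parameter is ignored.
                return func(arr, *args, **kwargs)
            raise ValueError(f"expected a {ndim}-d or {ndim + 1}-d array, got {arr.ndim}-d")
        return wrapper
    return decorator

@along_axis_nd(2)
def column_sums(image):
    # Example function of a 2d array (assumption, for illustration only).
    return image.sum(axis=0)

stack = np.ones((5, 3, 4))
print(column_sums(stack, axis=0).shape)  # (5, 4)
print(column_sums(stack[0]).shape)       # (4,)
```

Raising on any other dimensionality is a design choice on my part; the original wrapper silently passes such arrays straight to the function.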
Now I'll apply it to a function that is meant to act on 2d arrays. This one in particular is inspired by this post on radial means of images.
@along_axis
def radial_mean(image):
    """
    Computes radial means of a 2d image. After wrapping it, we can input 3d arrays as well.
    """
    # create a radial mesh at an image's center
    X, Y = np.meshgrid(np.arange(image.shape[1]), np.arange(image.shape[0]))
    R = np.sqrt((X - image.shape[1]//2)**2 + (Y - image.shape[0]//2)**2)
    # for each r in R
    r = np.arange(int(R.max()))
    # compute the radial mean
    f = np.vectorize(lambda r : image[(R >= r-0.5) & (R < r+0.5)].mean())
    return f(r)
Now I can compute the radial means for many images or just a single image like before.
images = np.random.rand(1000,16,16)
print(radial_mean(images, axis=0).shape)
print(radial_mean(images[0], axis=0).shape)
> (1000, 11)
> (11,)
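The trailing dimension of 11 comes from the number of radial bins for a 16x16 image. Here is a small standalone check of that arithmetic, which also shows the ring mask used for one radius (I use `np.indices` instead of `np.meshgrid` just for brevity):

```python
import numpy as np

height, width = 16, 16
# Row (Y) and column (X) index grids, equivalent to the meshgrid in radial_mean.
Y, X = np.indices((height, width))
R = np.sqrt((X - width // 2) ** 2 + (Y - height // 2) ** 2)

# The corner (0, 0) is farthest from the centre (8, 8): sqrt(8**2 + 8**2).
print(R.max())       # about 11.31
print(int(R.max()))  # 11 radial bins per image

# Boolean mask selecting the ring of pixels whose radius rounds to r = 3.
ring = (R >= 2.5) & (R < 3.5)
print(ring.sum())    # number of pixels averaged for r = 3
```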
Does this approach make sense? Is there a better way? Thank you.