What about indexing the first axis with a tuple of indices (i.e. advanced, or "fancy", indexing)?
E.g.:
arr[(0, 2, 3, 1, 4), :]
and:
arr[(0, 4, 1, 2, 3), :]
for your two expected outputs, respectively.
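A minimal sketch of this kind of indexing. It assumes a hypothetical 5×2 array standing in for the one in the question, which is not reproduced here. A list of indices behaves the same as the tuple, and the result is always a new array:

```python
import numpy as np

# Hypothetical stand-in for the array in the question: 5 rows, 2 columns.
arr = np.arange(10).reshape(5, 2)

# Row 1 moves to position 3 (rows 2 and 3 shift up by one).
out1 = arr[(0, 2, 3, 1, 4), :]
# Row 4 moves to position 1 (rows 1..3 shift down by one).
out2 = arr[(0, 4, 1, 2, 3), :]

print(out1[:, 0])  # [0 4 6 2 8]
print(out2[:, 0])  # [0 8 2 4 6]

# A list works the same way on the first axis.
print(np.array_equal(arr[[0, 2, 3, 1, 4]], out1))  # True
```

The trailing ", :" matters: without it, arr[(0, 2, 3, 1, 4)] would be read as one index per axis and fail on a 2-D array.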
To generate these indices from just the two positions involved, you could use something like the following:
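def inner_roll(arr, first, last, axis):
    # Move the element at position `first` to position `last` along `axis`
    # (same semantics as inner_roll2() below; works for first < last and first > last).
    indices = list(range(arr.shape[axis]))
    indices.insert(last, indices.pop(first))
    # Index with the permutation along `axis` only, keep all other axes whole.
    slicing = tuple(
        slice(None) if i != axis else indices
        for i in range(arr.ndim))
    # Advanced indexing always returns a copy.
    return arr[slicing]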
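If you prefer to stay in NumPy, the same kind of permutation can also be built with np.delete() and np.insert() and applied along any axis with np.take(), which avoids building the slicing tuple by hand. This is a sketch assuming the semantics "move the element at position first to position last"; move_perm and inner_roll_take are hypothetical names, not from the question:

```python
import numpy as np

def move_perm(n, first, last):
    """Hypothetical helper: permutation of range(n) moving `first` to position `last`."""
    # Drop `first` from 0..n-1, then re-insert it at position `last`.
    return np.insert(np.delete(np.arange(n), first), last, first)

def inner_roll_take(arr, first, last, axis):
    """Hypothetical variant of inner_roll() based on np.take (returns a copy)."""
    return np.take(arr, move_perm(arr.shape[axis], first, last), axis=axis)

arr = np.arange(10).reshape(5, 2)
print(move_perm(5, 1, 3))                   # [0 2 3 1 4]
print(move_perm(5, 4, 1))                   # [0 4 1 2 3]
print(inner_roll_take(arr, 4, 1, 0)[:, 0])  # [0 8 2 4 6]
```

np.take() accepts the axis directly, so the same call covers every axis without constructing a tuple of slices.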
For inputs that are relatively small along the rolled axis (such as the one in the question), this is quite fast.
Comparing it with a slightly polished version of @Mercury's answer (wrapped in a function and fixed to work correctly for an arbitrary axis):
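import numpy as np

def inner_roll2(arr, first, last, axis):
    # Move the element at position `first` to position `last` along `axis`
    # by rolling the slice between them by one step, in place.
    if first > last:
        first, last = last, first
        shift = 1
    else:
        shift = -1
    slicing = tuple(
        slice(None) if i != axis else slice(first, last + 1)
        for i in range(arr.ndim))
    # np.roll() returns a new array, so writing it back into the view is safe.
    arr[slicing] = np.roll(arr[slicing], shift=shift, axis=axis)
    return arr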
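The core of inner_roll2() is just np.roll() applied to a sub-slice and written back. A minimal 1-D illustration with made-up values: a shift of -1 moves the first element of the slice to its end, a shift of +1 moves the last element to its front:

```python
import numpy as np

a = np.arange(5)                      # [0 1 2 3 4]
# Move the element at index 1 to index 3: roll a[1:4] left by one.
a[1:4] = np.roll(a[1:4], shift=-1)
print(a)                              # [0 2 3 1 4]

b = np.arange(5)
# Move the element at index 4 to index 1: roll b[1:5] right by one.
b[1:5] = np.roll(b[1:5], shift=1)
print(b)                              # [0 4 1 2 3]

# np.roll() allocates a fresh array, which is why assigning it back
# into the overlapping slice does not corrupt the data.
print(np.shares_memory(np.roll(a[1:4], -1), a))  # False
```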
and getting some timings:
# IPython/Jupyter only: %timeit is an IPython magic command.
funcs = inner_roll, inner_roll2
for n in (5, 50, 500):
    for m in (2, 20, 200):
        arr = np.arange(n * m).reshape((n, m))
        print(f'({n:<3d}, {m:<3d})', end=' ')
        for func in funcs:
            results = %timeit -o -q func(arr, 1, 2, 0)
            print(f'{func.__name__:>12s} {results.best * 1e6:>7.3f} µs', end=' ')
        print()
# (5 , 2 ) inner_roll 5.613 µs inner_roll2 15.393 µs
# (5 , 20 ) inner_roll 5.592 µs inner_roll2 15.468 µs
# (5 , 200) inner_roll 5.916 µs inner_roll2 15.815 µs
# (50 , 2 ) inner_roll 10.117 µs inner_roll2 15.517 µs
# (50 , 20 ) inner_roll 10.360 µs inner_roll2 15.505 µs
# (50 , 200) inner_roll 12.067 µs inner_roll2 15.886 µs
# (500, 2 ) inner_roll 55.833 µs inner_roll2 15.409 µs
# (500, 20 ) inner_roll 57.364 µs inner_roll2 15.319 µs
# (500, 200) inner_roll 194.408 µs inner_roll2 15.731 µs
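Because %timeit only exists in IPython, here is a rough plain-Python stand-in based on the standard timeit module. To stay self-contained, it times only the core operation of each approach (building the index list plus a fancy-indexing copy, versus rolling the two-row slice in place), not the complete functions, so the absolute numbers will not match the table above:

```python
import timeit
import numpy as np

def best_us(func, number=200, repeat=3):
    """Best per-call time in microseconds (like .best of IPython's %timeit -o)."""
    return min(timeit.repeat(func, number=number, repeat=repeat)) / number * 1e6

results = {}
for n in (5, 50, 500):
    for m in (2, 20, 200):
        arr = np.arange(n * m).reshape((n, m))

        def fancy():
            # Core of inner_roll(): build the index list, then copy via fancy indexing.
            idx = list(range(n))
            idx.insert(2, idx.pop(1))  # swap rows 1 and 2, like func(arr, 1, 2, 0)
            return arr[idx, :]

        def sliced_roll():
            # Core of inner_roll2(): roll only rows 1..2, in place.
            arr[1:3] = np.roll(arr[1:3], shift=-1, axis=0)

        results[n, m] = (best_us(fancy), best_us(sliced_roll))
        t_fancy, t_roll = results[n, m]
        print(f'({n:<3d}, {m:<3d})  fancy {t_fancy:>8.3f} µs  roll {t_roll:>8.3f} µs')
```

Taking the minimum over several repeats, as %timeit's best does, reduces the influence of other processes on the measurement.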
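This indicates that inner_roll() is the faster approach for inputs as small as yours.
However, inner_roll2() scales much better with input size: its timing stays essentially flat, because it only touches the slice between first and last, while inner_roll() builds a Python list of n indices and copies the whole array. Already for modest sizes (n = 500 in the table above), inner_roll2() is the faster of the two.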
Note that, while inner_roll() creates a copy, inner_roll2() works in-place (it modifies the input arr). This behavior can be changed by adding arr = arr.copy() at the beginning of the body of inner_roll2(), which would make that function slower (of course), and its timings would then depend much more on the overall array size, including m (the size of the non-rolled axes).
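Whether a result shares memory with its input can be checked with np.shares_memory(). A small sketch on a hypothetical 5×2 array: fancy indexing (as in inner_roll()) yields an independent copy, basic slicing (as in inner_roll2()) yields a view, and copying first restores copy semantics at the price of copying the whole array:

```python
import numpy as np

arr = np.arange(10).reshape(5, 2)

fancy = arr[[0, 2, 1, 3, 4], :]       # advanced indexing -> new array
view = arr[1:3, :]                    # basic slicing -> view into arr
print(np.shares_memory(fancy, arr))   # False
print(np.shares_memory(view, arr))    # True

# An in-place roll of the slice changes arr itself...
arr[1:3] = np.roll(arr[1:3], shift=-1, axis=0)
print(arr[:, 0])                      # [0 4 2 6 8]

# ...unless the work is done on a copy (what arr = arr.copy() would do).
orig = np.arange(10).reshape(5, 2)
work = orig.copy()
work[1:3] = np.roll(work[1:3], shift=-1, axis=0)
print(orig[:, 0], work[:, 0])         # [0 2 4 6 8] [0 4 2 6 8]
```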
On the other hand, if you need several consecutive rolling operations, the inner_roll2() timings simply add up, whereas with inner_roll() you can chain the moves on the (cheap) index list and do the expensive part, the fancy-indexing copy, only once.
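A sketch of that idea. It assumes the "move the element at position first to position last" semantics and a list of made-up (first, last) moves; apply_moves is a hypothetical helper name. The moves are applied one after another to a plain Python index list, and the array is indexed only once at the end. The loop of in-place slice rolls is there only to check that both routes give the same result:

```python
import numpy as np

def apply_moves(arr, moves, axis=0):
    """Hypothetical helper: apply several (first, last) moves with a single copy."""
    indices = list(range(arr.shape[axis]))
    for first, last in moves:
        # Cheap: only the Python index list is rearranged.
        indices.insert(last, indices.pop(first))
    # Expensive part, done once: a single fancy-indexing copy along `axis`.
    return np.take(arr, indices, axis=axis)

arr = np.arange(10).reshape(5, 2)
moves = [(1, 3), (4, 1), (0, 2)]

# Reference: perform the same moves one by one with in-place slice rolls.
ref = arr.copy()
for first, last in moves:
    lo, hi = min(first, last), max(first, last)
    shift = -1 if first < last else 1
    ref[lo:hi + 1] = np.roll(ref[lo:hi + 1], shift=shift, axis=0)

print(apply_moves(arr, moves)[:, 0])                 # [8 4 0 6 2]
print(np.array_equal(apply_moves(arr, moves), ref))  # True
```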