
Let a be some (not necessarily one-dimensional) NumPy array with n * m elements along its last axis. I wish to "split" this array along its last axis so that, for each offset i from 0 up to n - 1, I take every n-th element starting at i.

To be explicit, let a have shape (k, n * m). Then I wish to construct the array of shape (n, k, m):

np.array([a[:, i::n] for i in range(n)])
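A small concrete case may make the intended layout clearer. The sketch below uses arbitrary illustrative sizes (k=2, n=3, m=2), builds the result with the list comprehension above, and checks the element-wise meaning out[i, r, c] == a[r, c*n + i].

```python
import numpy as np

# Arbitrary small sizes, chosen only for illustration.
k, n, m = 2, 3, 2
a = np.arange(k * n * m).reshape(k, n * m)
# a == [[0, 1, 2, 3,  4,  5],
#       [6, 7, 8, 9, 10, 11]]

# Slice i takes every n-th column of a, starting at column i.
out = np.array([a[:, i::n] for i in range(n)])
print(out.shape)  # (3, 2, 2)

# Element-wise meaning of the result: out[i, r, c] == a[r, c*n + i].
for i in range(n):
    for r in range(k):
        for c in range(m):
            assert out[i, r, c] == a[r, c * n + i]

print(out[0])  # [[0 3]
               #  [6 9]]
```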

My problem is that, though this indeed returns the array I seek, I still feel there might be a more efficient and neater NumPy routine for this.

Cheers!

  • You cannot do that: NumPy does not support jagged arrays, that is, arrays containing arrays of different sizes (you can cheat with a NumPy array containing NumPy arrays, but this is not really a single NumPy array, and it is very slow). The provided code fails with n=2, m=5 and k=3. – Jérôme Richard Nov 18 '22 at 22:59
  • Sorry, my bad! The range was supposed to stop at n and not m. The resulting array should have shape (n, k, m). I've edited the original question to reflect this. – William Crawford Nov 19 '22 at 09:26

3 Answers


It is hard to write a faster NumPy implementation. One efficient solution is to use Numba to speed this up. That being said, the memory access pattern can be the main reason why the code is slow on relatively large matrices, so one needs to care about the iteration order to keep the accesses relatively cache-friendly. Moreover, for large arrays, it can be a good idea to use multiple threads to better mitigate the overhead coming from the relatively high memory latency (due to the strided access pattern). Here is an implementation:

import numba as nb
import numpy as np

# The first call is slower due to the build.
# Please consider specifying the signature of the function (i.e. input types)
# to precompile the function ahead of time.
@nb.njit # Use nb.njit(parallel=True) for the parallel version
def compute(arr, n):
    k, m = arr.shape[0], arr.shape[1] // n
    assert arr.shape[1] == n * m

    out = np.empty((n, k, m), dtype=arr.dtype)

    # Use nb.prange for the parallel version
    for i2 in range(k):
        for i1 in range(n):
            outView = out[i1, i2]  # destination row: out[i1, i2, :]
            inView = arr[i2]       # source row: arr[i2, :]
            cur = i1               # start at offset i1, then step by n
            for i3 in range(m):
                outView[i3] = inView[cur]
                cur += n

    return out

Here are the results on my machine with an i5-9600KF processor (6 cores) for k=37, n=42, m=53 and a.dtype=np.int32:

John Zwinck's solution:    986.1 µs
Initial implementation:     91.7 µs
Sequential Numba:           62.9 µs
Parallel Numba:             14.7 µs
Optimal lower-bound:        ~7.0 µs
Jérôme Richard

I think this does what you want, without Python loops. I tested it with 2D inputs; it may need some adjustments for more dimensions.

indexes = np.arange(0, a.size*n, n) + np.repeat(np.arange(n), a.size//n)
np.take(a, indexes, mode='wrap').reshape(n, a.shape[0], -1)

In my testing it is a bit slower than your original list solution.
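To see why the wrapped indices work, the sketch below reuses arbitrary small sizes (k=2, n=3, m=2), prints the flat indices after wrapping, and checks the result against the list comprehension from the question. It uses integer division `a.size // n`, since `np.repeat` needs an integer repeat count.

```python
import numpy as np

k, n, m = 2, 3, 2
a = np.arange(k * n * m).reshape(k, n * m)

# For output position p, block i = p // (a.size // n) and offset j = p % (a.size // n).
# The raw index p*n + i wraps (mod a.size) to j*n + i, i.e. a.flat[j*n + i].
indexes = np.arange(0, a.size * n, n) + np.repeat(np.arange(n), a.size // n)
print(indexes % a.size)  # [ 0  3  6  9  1  4  7 10  2  5  8 11]

result = np.take(a, indexes, mode='wrap').reshape(n, a.shape[0], -1)
reference = np.array([a[:, i::n] for i in range(n)])
assert np.array_equal(result, reference)
```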

John Zwinck

If I'm not mistaken, this does what you expect, and fast, since it only reshapes and swaps axes (returning a view, without copying any data):

a.reshape(k, m, n).swapaxes(1, 2).swapaxes(0, 1)

Example:

import numpy as np
k=5; n=3; m=4
a = np.arange(k*n*m).reshape(k, n*m)
a.reshape(k, m, n).swapaxes(1, 2).swapaxes(0, 1)
"""
array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
       [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23],
       [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
       [36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
       [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59]])

is transformed into:

array([[[ 0,  3,  6,  9],
        [12, 15, 18, 21],
        [24, 27, 30, 33],
        [36, 39, 42, 45],
        [48, 51, 54, 57]],

       [[ 1,  4,  7, 10],
        [13, 16, 19, 22],
        [25, 28, 31, 34],
        [37, 40, 43, 46],
        [49, 52, 55, 58]],

       [[ 2,  5,  8, 11],
        [14, 17, 20, 23],
        [26, 29, 32, 35],
        [38, 41, 44, 47],
        [50, 53, 56, 59]]])
"""

Timing:

from time import time
import numpy as np

k=37; n=42; m=53
a = np.arange(k*n*m).reshape(k, n*m)

start = time()
for _ in range(1_000_000):
    res = a.reshape(k, m, n).swapaxes(1, 2).swapaxes(0,1)
print(time() - start)

# 0.95 s per 1 mil repetitions
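Because this expression only creates a view, the loop above mostly measures per-call overhead, while the timings in the other answers include writing the full n * k * m output. A like-for-like comparison could also time the view followed by a contiguous copy. The sketch below uses the standard `timeit` module with the same sizes; the labels and the repeat count are arbitrary choices, and the absolute numbers will depend on the machine and NumPy version, so none are quoted here.

```python
import timeit
import numpy as np

k, n, m = 37, 42, 53
a = np.arange(k * n * m).reshape(k, n * m)

candidates = {
    # Question's list comprehension (copies data into a new array).
    "list comprehension": lambda: np.array([a[:, i::n] for i in range(n)]),
    # Reshape + transpose: returns a view, no data is moved.
    "view only": lambda: a.reshape(k, m, n).transpose(2, 0, 1),
    # Same view, then copied into a new C-contiguous array.
    "view + copy": lambda: np.ascontiguousarray(a.reshape(k, m, n).transpose(2, 0, 1)),
}

reference = candidates["list comprehension"]()
for name, fn in candidates.items():
    assert np.array_equal(fn(), reference)  # all variants agree element-wise
    runs = 1_000
    seconds = timeit.timeit(fn, number=runs)
    print(f"{name:>20}: {seconds / runs * 1e6:8.1f} µs per call")
```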
isCzech