2

To start, let's consider a black box that takes two variables m and n (where m is a multiple of n) and outputs a 2D matrix of shape (m*n, m*n). This 2D matrix now needs to be transformed into a 4D matrix with the shape (m, m, n, n). I am unsure of the best way to describe this in writing, but the data is structured such that within the 2D (m*n)x(m*n) matrix there are m lots of (nxn) "tiles" in each direction. Consider an example array a; in this case we have m = 3 and n = 2, so the incoming 2D matrix is 6x6:

print(a)
[[ 0  1  4  5  8  9]
 [ 2  3  6  7 10 11]
 [12 13 16 17 20 21]
 [14 15 18 19 22 23]
 [24 25 28 29 32 33]
 [26 27 30 31 34 35]]

This is then passed into some function:

b = some_func(a)

The required output is the 4D array:

print(b)
[[[[ 0  1]
   [ 2  3]]

  [[ 4  5]
   [ 6  7]]

  [[ 8  9]
   [10 11]]]


 [[[12 13]
   [14 15]]

  [[16 17]
   [18 19]]

  [[20 21]
   [22 23]]]


 [[[24 25]
   [26 27]]

  [[28 29]
   [30 31]]

  [[32 33]
   [34 35]]]]

To put it into words, we need to separate out the nxn "tiles" within the larger 2D array. What this really represents is an mxm matrix in which each entry is itself an nxn matrix, giving a 4D matrix on which we can then do further work. This is a highly simplified example for demonstrative purposes; the real system is much more complicated, with a lot more going on. In my case there is also an extra axis, m = 256, the entries in the matrix are complex (64-bit) and we are highly concerned about performance, but these details are irrelevant to the issue. If it helps at all, n = 2 is the only case we are concerned with, although I would hope that there is a more general solution.

I can reasonably conceive of a solution that uses for loops, indexing, modulo arithmetic, etc., but this would be drastically inefficient in Python.
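Roughly, this is the kind of loop-based approach I have in mind (just a sketch for illustration; the function name is a placeholder):

import numpy as np

def tiles_to_4d_loops(a, m, n):
    # Copy each n x n tile of the (m*n, m*n) array into its own slot of b.
    b = np.empty((m, m, n, n), dtype=a.dtype)
    for i in range(m):        # tile row
        for j in range(m):    # tile column
            b[i, j] = a[i * n:(i + 1) * n, j * n:(j + 1) * n]
    return b

This produces the required b for the example above, but it copies every tile individually, which is exactly the kind of explicit Python-level work I would like to avoid.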

Potential Solutions?

  1. The mind instantly jumps to something like np.reshape(), however we cannot simply use a.reshape(m, m, n, n), as the correct order is not preserved due to the way np.reshape() first ravels the array, as is outlined in a similar issue; after much deliberation I am strongly convinced that the solution for that issue will not work in this case. (See the short demonstration after this list.) I thought an np.reshape(), an np.swapaxes(), another np.reshape() and an np.swapaxes() back might work, but even if such a method did work it seems very inefficient. It is entirely conceivable that some concoction of np.reshape() and np.swapaxes() calls will provide a solution, but I haven't been successful.
  2. A colleague suggested off-hand that np.einsum() is an incredibly powerful and generalizable method for performing matrix operations(?), however I have not been successful with this.
  3. The most likely solution: there is a specific "Pythonic" way of doing things that I am missing - a certain numpy function I am unaware of which will do the exact trick!
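To illustrate point 1, the plain reshape groups the wrong elements together when applied to the example array a above:

print(a.reshape(3, 3, 2, 2)[0, 0])
# [[0 1]
#  [4 5]]   <- not the required first tile [[0 1], [2 3]]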

I hope that my description of the issue is sufficient. The context surrounding this issue is highly complex (radio astronomy image processing) and would be extraordinarily cumbersome to explain in full, so please take the problem at face value and the initial assumptions as given when providing any solutions.

Here's the code to reproduce the test problem.

import numpy as np

a = np.array([[0, 1, 4, 5, 8, 9], [2, 3, 6, 7, 10, 11], [12, 13, 16, 17, 20, 21], [14, 15, 18, 19, 22, 23], [24, 25, 28, 29, 32, 33], [26, 27, 30, 31, 34, 35]])

EDIT: Reproduced the test problem with a little bit of TeX for further clarity: Matrix operation in question

Liam Ryan
  • With reshape `a.reshape(3, 2, -1, 2).swapaxes(1,2).reshape(3, 3, 2, 2)`, but I think `np.einsum` could be simpler. From my understanding `reshape` and `swapaxes` are O(1) operations and efficient. Please correct me if I'm wrong. – Michael Szczesny Jan 17 '22 at 11:11
  • Microbenchmark: `reshape` is ~80x faster (87.1µs/1.13µs) than the naive (creating new sub-arrays) solution `np.array([np.hsplit(sub, 3) for sub in np.vsplit(a, 3)])` for your example input. – Michael Szczesny Jan 17 '22 at 11:29
  • The operations reshape and swapaxes don't perform copy operations if possible. In the case of the operations described by Michael, a view of the array `a` is created, which in this case is of constant complexity. It can be checked e.g. by modifying a value in the result and seeing whether the value is modified in `a` as well. `np.einsum` indeed is a very powerful operation, but it does not solve the problem at hand. It is probably useful for the operation performed on the result of the reshaping, for which I would need more info to comment. – Jonathan Weine Jan 17 '22 at 11:34
  • @MichaelSzczesny Thank you for your help, `a.reshape(3, 2, -1, 2).swapaxes(1,2).reshape(3, 3, 2, 2)` worked exactly as required. I knew it would be something along those lines, but wouldn't have got the solution myself, especially the first reshape. In my actual situation what was required was `a.reshape(k, m, n, -1, n).swapaxes(2,3).reshape(k, m, m, n, n)`, where *k* is an additional axis associated with the data. I'm also getting similar performance results compared to the "naive" method; I had also thought of that method, but knew that the creation of sub-arrays always brings in inefficiency. – Liam Ryan Jan 17 '22 at 15:02
  • @JonathanWeine Thanks for the input, `np.einsum()` is indeed something I have come across in the codebase I am working with and my colleagues have touted its ability; I will look into it further for future work. After some further digging I believe you are both correct that the `reshape` and `swapaxes` operations are O(1); I was incorrect in my assumption that such a solution would be inefficient. I'm somewhat new to Python so I'm still learning the workings of such numpy functions and how to extract the most performance. Big thanks once again. – Liam Ryan Jan 17 '22 at 15:09
  • @LiamRyan Actually you are not creating arrays (just views) by `a.reshape(3, 2, -1, 2).swapaxes(1,2).reshape(3, 3, 2, 2)` if the input array is contiguous. Therefore the performance is independent of array size. For this rather small example you wouldn't see much difference between copying and a view; if you make the example 10,000 times larger you will definitely see a difference. – max9111 Jan 17 '22 at 15:10
  • @max9111 Thanks for the much needed clarification, the performance characteristics make a lot more sense in that context. The data sets we'll be working with are extremely large so I'm sure that information will be relevant. – Liam Ryan Jan 17 '22 at 15:15
  • The einsum solution would look like `A = A.reshape(3, 2, -1, 2)  # view array as (m,n,m,n)` followed by `B = np.einsum('ijkl->ikjl', A)  # the same as the axis swap`. The easiest check is to look at the `flags` attribute of the array: if it doesn't own its data, it is just a view. – max9111 Jan 17 '22 at 15:18
  • @max9111, a reshape after a `swapaxes` will (usually) create a copy. `reshape` says to think of it as first doing a ravel. Do the `ravel` after a swap and you'll see a different order of elements (compared to the original). `swapaxes`/`transpose` work by changing `shape` and `strides`, and don't need to make a copy. But further processing often does make a copy, since the values are no longer 'C' contiguous. – hpaulj Jan 17 '22 at 17:24
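For reference, here are the suggestions from these comments collected into one runnable sketch (dimensions written out explicitly for the example; the einsum line is just another spelling of the axis swap):

import numpy as np

m, n = 3, 2
a = np.array([[0, 1, 4, 5, 8, 9], [2, 3, 6, 7, 10, 11], [12, 13, 16, 17, 20, 21],
              [14, 15, 18, 19, 22, 23], [24, 25, 28, 29, 32, 33], [26, 27, 30, 31, 34, 35]])

# reshape + swapaxes: a view of a (no data is copied), with shape (m, m, n, n)
b = a.reshape(m, n, m, n).swapaxes(1, 2)

# the einsum spelling of the same axis swap
b_einsum = np.einsum('ijkl->ikjl', a.reshape(m, n, m, n))

print(np.array_equal(b, b_einsum))   # True
print(b[0, 1])                       # [[4 5]
                                     #  [6 7]]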

2 Answers

2

I'll try to illustrate the issues discussed in the comments.

A starting array - a reshape of a 1d arange:

In [160]: arr = np.arange(16).reshape(4,4)
In [161]: arr
Out[161]: 
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15]])
In [162]: arr.ravel()
Out[162]: array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])
In [163]: arr.strides
Out[163]: (32, 8)

Further reshape to 4d. Note the ravel is the same. I could also use arr1.__array_interface__ to show the data buffer id.

In [164]: arr1 = arr.reshape(2,2,2,2)
In [165]: arr1.ravel()
Out[165]: array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])
In [166]: arr1.strides            
Out[166]: (64, 32, 16, 8)
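For example, a quick check of the buffer pointers (outside the session above) confirms that arr1 is just another view of the same data:

arr.__array_interface__['data'][0] == arr1.__array_interface__['data'][0]
# True - both describe the same underlying buffer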

It would be a good idea to test your understanding of how the strides change with the change in shape.

Now swap:

In [167]: arr2 = arr1.swapaxes(1,2)
In [168]: arr2
Out[168]: 
array([[[[ 0,  1],
         [ 4,  5]],

        [[ 2,  3],
         [ 6,  7]]],


       [[[ 8,  9],
         [12, 13]],

        [[10, 11],
         [14, 15]]]])
In [169]: arr2.strides
Out[169]: (64, 16, 32, 8)

Still (2,2,2,2), but the strides have changed. This too is a view. But a reshape of this (including a ravel) will make a copy; the elements have been reordered:

In [170]: arr2.ravel()
Out[170]: array([ 0,  1,  4,  5,  2,  3,  6,  7,  8,  9, 12, 13, 10, 11, 14, 15])
In [171]: arr3 = arr2.reshape(4,4)
In [172]: arr3
Out[172]: 
array([[ 0,  1,  4,  5],
       [ 2,  3,  6,  7],
       [ 8,  9, 12, 13],
       [10, 11, 14, 15]])
In [173]: arr3.ravel()
Out[173]: array([ 0,  1,  4,  5,  2,  3,  6,  7,  8,  9, 12, 13, 10, 11, 14, 15])
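As a further check (again outside the original session), np.shares_memory makes the view/copy difference explicit:

np.shares_memory(arr, arr2)   # True  - swapaxes only changed shape and strides
np.shares_memory(arr, arr3)   # False - the reshape after swapaxes copied the data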

We see the same change in strides in a simpler 2d transpose:

In [174]: arr4 = arr.T
In [175]: arr4.strides
Out[175]: (8, 32)
In [176]: arr4.ravel()
Out[176]: array([ 0,  4,  8, 12,  1,  5,  9, 13,  2,  6, 10, 14,  3,  7, 11, 15])

We can make the ravel a view by specifying 'F' (column-major) order, though that may not help with the understanding. Order does not readily extend to higher dimensions, but strides do.

In [177]: arr4.ravel(order='F')
Out[177]: array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])
hpaulj
  • But it's the ravel at the end which causes the copy, right? I made a small benchmark of Michael's code and the timing was independent of array size (I was a bit surprised by this behaviour too). – max9111 Jan 17 '22 at 20:25
  • Reshape followed by swap is a view, and size independent. It's the subsequent operations that might require the copy. – hpaulj Jan 17 '22 at 22:08
0

I am not sure if this is better/faster than hpaulj's answer; I guess it will be similar in performance. Please check yourself on big arrays:

a.reshape((m, m * n * n)).reshape(m, n, m * n).transpose(0, 2, 1).reshape(m, m, n, n).transpose(0, 1, 3, 2)
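For reference, a quick way to check this one-liner against the reshape/swapaxes version from the comments, using the example values from the question:

import numpy as np

m, n = 3, 2
a = np.array([[0, 1, 4, 5, 8, 9], [2, 3, 6, 7, 10, 11], [12, 13, 16, 17, 20, 21],
              [14, 15, 18, 19, 22, 23], [24, 25, 28, 29, 32, 33], [26, 27, 30, 31, 34, 35]])

b1 = a.reshape((m, m * n * n)).reshape(m, n, m * n).transpose(0, 2, 1).reshape(m, m, n, n).transpose(0, 1, 3, 2)
b2 = a.reshape(m, n, m, n).swapaxes(1, 2)
print(np.array_equal(b1, b2))   # True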
dankal444