To start, consider a black box which, for two integers m and n (where m is a multiple of n), outputs a 2D matrix of shape (m*n, m*n). This 2D matrix must be transformed into a 4D matrix of shape (m, m, n, n). I am unsure how best to describe this in writing, but the data are structured so that the (m*n)x(m*n) matrix consists of m (nxn) "tiles" in each direction, i.e. m*m tiles in total. Consider an example array a with m = 3 and n = 2, so the incoming 2D matrix is 6x6:
```
print(a)
[[ 0  1  4  5  8  9]
 [ 2  3  6  7 10 11]
 [12 13 16 17 20 21]
 [14 15 18 19 22 23]
 [24 25 28 29 32 33]
 [26 27 30 31 34 35]]
```
This is then passed into some function:
```
b = some_func(a)
```
for which the required output is the 4D array:
```
print(b)
[[[[ 0  1]
   [ 2  3]]

  [[ 4  5]
   [ 6  7]]

  [[ 8  9]
   [10 11]]]


 [[[12 13]
   [14 15]]

  [[16 17]
   [18 19]]

  [[20 21]
   [22 23]]]


 [[[24 25]
   [26 27]]

  [[28 29]
   [30 31]]

  [[32 33]
   [34 35]]]]
```
To put it into words, we need to separate out the nxn "tiles" within the larger 2D array. What this represents is an mxm matrix in which each entry is itself an nxn matrix, i.e. a 4D array on which we then do further work. This is a highly simplified example for demonstration purposes; the real system is much more complicated. In my case there is also an extra axis, m = 256, the matrix entries are complex (64-bit), and performance is a major concern; however, these details are irrelevant to the issue. If it helps, n = 2 is the only case we are concerned with, but I would hope that there is a more general solution.
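Stated as an index mapping for general m and n (0-based; the index names i, j, p, q are introduced here for illustration and do not come from the original problem), tile (i, j) of a becomes b[i, j], and element (p, q) inside that tile sits at row i*n + p and column j*n + q of a. For example, with m = 3 and n = 2, b[1, 2, 0, 1] = a[2, 5] = 21.

$$
b_{i,j,p,q} = a_{\,i n + p,\; j n + q}, \qquad 0 \le i, j < m, \quad 0 \le p, q < n.
$$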
I can conceive of a solution that uses for loops, indexing, modulo arithmetic, etc., but it would be drastically inefficient in Python.
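For reference, a loop-based version could look roughly like the sketch below (the name tiles_loop is hypothetical, and it uses slicing rather than modulo arithmetic). It is slow in pure Python but handy as a correctness check for any vectorized candidate.

```python
import numpy as np

def tiles_loop(a: np.ndarray, m: int, n: int) -> np.ndarray:
    """Naive reference: copy each n x n tile of an (m*n, m*n) matrix into b[i, j]."""
    b = np.empty((m, m, n, n), dtype=a.dtype)
    for i in range(m):          # tile row
        for j in range(m):      # tile column
            # Tile (i, j) spans rows i*n .. i*n+n-1 and columns j*n .. j*n+n-1 of a.
            b[i, j] = a[i * n:(i + 1) * n, j * n:(j + 1) * n]
    return b
```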
Potential Solutions?
- The mind instantly jumps to something like np.reshape(); however, we cannot simply use a.reshape(m, m, n, n), because the correct order is not preserved due to the way np.reshape() first ravels the array, as outlined in a similar issue. After much deliberation, I am strongly convinced that the solution to that issue will not work in this case. I thought that an np.reshape(), an np.swapaxes(), another np.reshape() and an np.swapaxes() back might work, but alas, even if such a method worked it seems very inefficient. It is entirely conceivable that some concoction of np.reshape() and np.swapaxes() calls will provide a solution, but I haven't been successful.
- A colleague suggested off-hand that np.einsum() is an incredibly powerful and generalizable method for performing matrix operations(?), but I have not been successful with it.
- The most likely solution: there is a specific "Pythonic" way of doing this that I am missing, i.e. a certain NumPy function I am unaware of that does exactly this!
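Going by the index mapping b[i, j, p, q] = a[i*n + p, j*n + q], one candidate worth checking is a single reshape followed by a single axis swap: reshaping to (m, n, m, n) splits each row and column index into (tile index, offset within tile), and swapping the two middle axes puts both tile indices first. The einsum variant expresses the same permutation with subscripts and sums nothing. This is a sketch under the assumption that any extra axis (such as the one in the real problem) comes before the two matrix axes; tiles_reshape and tiles_einsum are hypothetical names.

```python
import numpy as np

def tiles_reshape(a: np.ndarray, m: int, n: int) -> np.ndarray:
    """Return the (..., m, m, n, n) tile grid of a (..., m*n, m*n) array.

    Leading axes (e.g. an extra batch axis) are carried through unchanged.
    """
    lead = a.shape[:-2]
    # Split each length-(m*n) axis into (tile index, offset in tile):
    # (..., m*n, m*n) -> (..., m, n, m, n), i.e. axes (..., i, p, j, q)
    t = a.reshape(*lead, m, n, m, n)
    # Swap the middle pair: (..., i, p, j, q) -> (..., i, j, p, q)
    return t.swapaxes(-3, -2)

def tiles_einsum(a: np.ndarray, m: int, n: int) -> np.ndarray:
    """Same axis permutation written as einsum subscripts (nothing is summed)."""
    lead = a.shape[:-2]
    return np.einsum("...ipjq->...ijpq", a.reshape(*lead, m, n, m, n))

a = np.array([[0, 1, 4, 5, 8, 9], [2, 3, 6, 7, 10, 11], [12, 13, 16, 17, 20, 21],
              [14, 15, 18, 19, 22, 23], [24, 25, 28, 29, 32, 33], [26, 27, 30, 31, 34, 35]])
b = tiles_reshape(a, 3, 2)
print(b.shape)   # (3, 3, 2, 2)
print(b[0, 1])   # tile (0, 1): rows 0-1, columns 2-3 of a
```

When a is C-contiguous, both reshape and swapaxes return views, so this step copies no data. The resulting b is not contiguous; if later code needs contiguous memory, np.ascontiguousarray(b) makes a single copy. Assuming the extra axis comes first, the same call applied to a complex64 array of shape (k, 512, 512) with m = 256, n = 2 would give shape (k, 256, 256, 2, 2).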
I hope that my description of the issue is sufficient. The context surrounding this issue (radio astronomy image processing) is highly complex and would be extraordinarily cumbersome to explain in full, so please take the problem at face value and the initial assumptions as given when providing any solutions.
Here's the line of code to reproduce the test problem.
```python
import numpy as np
a = np.array([[0, 1, 4, 5, 8, 9], [2, 3, 6, 7, 10, 11], [12, 13, 16, 17, 20, 21], [14, 15, 18, 19, 22, 23], [24, 25, 28, 29, 32, 33], [26, 27, 30, 31, 34, 35]])
```
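To test at other sizes (for example m = 256, n = 2), an input with the same tile layout can be generated by running the mapping backwards from np.arange, so the expected 4D result is simply np.arange(m*m*n*n).reshape(m, m, n, n). The helper names make_test_input and tiles_windows are hypothetical. The second helper is an independent cross-check that avoids reshape entirely: it uses numpy.lib.stride_tricks.sliding_window_view (NumPy 1.20 or later) and keeps only the n x n windows that start at multiples of n.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def make_test_input(m: int, n: int) -> np.ndarray:
    """(m*n, m*n) matrix whose n x n tiles, in row-major tile order, hold 0, 1, 2, ..."""
    expected = np.arange(m * m * n * n).reshape(m, m, n, n)
    # Run the mapping backwards: (i, j, p, q) -> (i, p, j, q) -> (i*n + p, j*n + q).
    # reshape copies here because the swapped array is not contiguous.
    return expected.swapaxes(1, 2).reshape(m * n, m * n)

def tiles_windows(a: np.ndarray, n: int) -> np.ndarray:
    """Tile grid via sliding windows; returns a read-only view of shape (m, m, n, n)."""
    # All n x n windows: shape (m*n - n + 1, m*n - n + 1, n, n).
    # Step n in both directions keeps only windows starting at multiples of n.
    return sliding_window_view(a, (n, n))[::n, ::n]

m, n = 3, 2
a = make_test_input(m, n)   # same 6x6 array as in the question
b = tiles_windows(a, n)
print(np.array_equal(b, np.arange(m * m * n * n).reshape(m, m, n, n)))  # True
```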
EDIT: Reproduced the test problem with a little bit of TeX for further clarity: