
While implementing a Kronecker product for pedagogical reasons (without using the obvious and readily available np.kron()), I obtained a 4-dimensional array as an intermediate result, which I have to reshape to get the final result.

But, I still can't wrap my head around reshaping these high dimensional arrays. I have this 4D array:

array([[[[ 0,  0],
         [ 0,  0]],

        [[ 5, 10],
         [15, 20]]],


       [[[ 6, 12],
         [18, 24]],

        [[ 7, 14],
         [21, 28]]]])

This is of shape (2, 2, 2, 2) and I'd like to reshape it to (4,4). One might think that this is obvious to do with

np.reshape(my4darr, (4,4))

But, the above reshape does not give me the expected result which is:

array([[ 0,  5,  0, 10],
       [ 6,  7, 12, 14],
       [ 0, 15,  0, 20],
       [18, 21, 24, 28]])

As you can see, all the elements in the expected result are present in the 4D array. I just can't get the hang of doing the reshape correctly as needed. In addition to the answer, some explanation of how to do the reshape for such high dimensional arrays would be really helpful. Thanks!

cs95
kmario23

3 Answers


General idea for nd to nd transformation

The idea with such an nd to nd transformation is to use just two things -

Permute axes : To get the axes into an order such that the flattened version of the array matches the flattened version of the output. A single permutation should be enough; if you somehow end up using it twice, look again, because you shouldn't need to.

Reshape : To split the axes or bring the final output to the desired shape. Splitting axes is needed mostly at the start, when the input is of lower-dim and we need to split it into blocks. Again, you shouldn't need this more than twice.

Hence, generally we would have three steps :

    [ Reshape ]      --->  [ Permute axes ]   --->  [ Reshape ]

 Create more axes             Bring axes             Merge axes
                          into correct order
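
As a rough illustration of those three steps, here is a minimal sketch that rearranges a 4x4 array so that each row holds one 2x2 block. The array `a`, the block size and the variable names are assumptions made up for this example; they are not part of the question.

```python
import numpy as np

# Hypothetical 4x4 input; the values are chosen only so the blocks are easy to spot.
a = np.arange(16).reshape(4, 4)

# Step 1 - reshape: split rows and columns into (block index, index inside block).
split = a.reshape(2, 2, 2, 2)            # axes: (block_row, row, block_col, col)

# Step 2 - permute axes: bring the two block indices to the front.
permuted = split.transpose(0, 2, 1, 3)   # axes: (block_row, block_col, row, col)

# Step 3 - reshape: merge axes so that each row is one flattened 2x2 block.
blocks = permuted.reshape(4, 4)
print(blocks)
```

Here the first reshape only creates axes, the transpose does the actual reordering, and the last reshape only merges axes.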

Back-tracking method

The safest way to solve this, given the input and output, is what one could call the back-tracking method: split the axes of the input (when going from a smaller nd to a bigger nd) or split the axes of the output (when going from a bigger nd to a smaller nd). The idea of the splitting is to bring the number of dims of the smaller nd one up to that of the bigger nd one. Then, study the strides of the output and match them up against the input to get the required permute order. Finally, a reshape (the default way, i.e. C order) might be needed at the end to merge axes, if the final one is the smaller nd one.

If both input and output are of same number of dims, then we would need to split both and break into blocks and study their strides against each other. In such cases, we should have the additional input parameter of block sizes, but that's probably off-topic.

Example

Let's use this specific case to demonstrate how to apply those strategies. Here, the input is 4D, while the output is 2D. So, most probably, we won't need a reshape to split axes, and we can start with permuting axes. Since the final output is not 4D but 2D, we would need a reshape at the end.

Now, the input here is :

In [270]: a
Out[270]: 
array([[[[ 0,  0],
         [ 0,  0]],

        [[ 5, 10],
         [15, 20]]],


       [[[ 6, 12],
         [18, 24]],

        [[ 7, 14],
         [21, 28]]]])

The expected output is :

In [271]: out
Out[271]: 
array([[ 0,  5,  0, 10],
       [ 6,  7, 12, 14],
       [ 0, 15,  0, 20],
       [18, 21, 24, 28]])

Also, this is a bigger nd to smaller nd transformation, so the back-tracking method involves splitting the output, studying its strides and matching them up against the corresponding values in the input :

                    axis = 3
                   ---      -->          
                                        
                    axis = 1                    
                   ------>           
axis=2|  axis=0|   [ 0,  5,  0, 10],        

               |   [ 6,  7, 12, 14],
               v  
      |            [ 0, 15,  0, 20],
      v
                   [18, 21, 24, 28]

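Reading the diagram: going down the rows of the output, the outer (two-row) step follows input axis 2 and the inner (one-row) step follows input axis 0; going across the columns, the outer step follows input axis 3 and the inner step follows input axis 1. Hence, the permuted order needed is (2,0,3,1) :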

In [275]: a.transpose((2, 0, 3, 1))
Out[275]: 
array([[[[ 0,  5],
         [ 0, 10]],

        [[ 6,  7],
         [12, 14]]],


       [[[ 0, 15],
         [ 0, 20]],

        [[18, 21],
         [24, 28]]]])
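
One way to double-check this order is to do the back-tracking split literally: reshape the expected 2D output into 4D and compare it with the permuted input. The following is only a small sketch; the names `out4`, `order` and `matches` are made up for illustration.

```python
import numpy as np

a = np.array([[[[0, 0], [0, 0]], [[5, 10], [15, 20]]],
              [[[6, 12], [18, 24]], [[7, 14], [21, 28]]]])
out = np.array([[0, 5, 0, 10],
                [6, 7, 12, 14],
                [0, 15, 0, 20],
                [18, 21, 24, 28]])

# Split each output axis in two so that the output is 4D like the input.
# Axes of out4: (row_outer, row_inner, col_outer, col_inner)
out4 = out.reshape(2, 2, 2, 2)

# Input axes matched to those four output axes, as read off the diagram.
order = (2, 0, 3, 1)
matches = np.array_equal(a.transpose(order), out4)
print(matches)
```

If the comparison is true, the only thing left to do is to merge the axes again.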

Then, simply reshape to the expected shape :

In [276]: a.transpose((2, 0, 3, 1)).reshape(4,4)
Out[276]: 
array([[ 0,  5,  0, 10],
       [ 6,  7, 12, 14],
       [ 0, 15,  0, 20],
       [18, 21, 24, 28]])
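
Since the question started from a Kronecker product, the result can also be sanity-checked against np.kron. This sketch assumes the 4D array is the outer product of two 2x2 factors `A` and `B`; these factors are not given in the question and are only inferred here from the values in the array.

```python
import numpy as np

# Assumed factors, inferred from the values of the 4D array (not given in the question).
A = np.array([[0, 5], [6, 7]])
B = np.array([[1, 2], [3, 4]])

# Outer product: a[i, j, k, l] = A[i, j] * B[k, l], shape (2, 2, 2, 2).
a = np.multiply.outer(A, B)

# The order used above reproduces the Kronecker product with the factors swapped.
result = a.transpose(2, 0, 3, 1).reshape(4, 4)
same_as_kron_BA = np.array_equal(result, np.kron(B, A))

# For comparison, the order (0, 2, 1, 3) reproduces np.kron(A, B).
same_as_kron_AB = np.array_equal(a.transpose(0, 2, 1, 3).reshape(4, 4), np.kron(A, B))
print(same_as_kron_BA, same_as_kron_AB)
```

Under these assumed factors, the expected output in the question corresponds to np.kron(B, A) rather than np.kron(A, B).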

More examples

I dug up my history and found a few Q&As based on nd to nd transformations. These could serve as other example cases, albeit (mostly) with less explanation. As mentioned earlier, at most two reshapes and at most one swapaxes/transpose did the job everywhere. They are listed below :

Divakar

It seems like you're looking for a transpose followed by a reshape.

x.transpose((2, 0, 3, 1)).reshape(np.prod(x.shape[:2]), -1)

array([[ 0,  5,  0, 10],
       [ 6,  7, 12, 14],
       [ 0, 15,  0, 20],
       [18, 21, 24, 28]])

To help you understand why a transposition is needed, let's analyse the incorrectly shaped output (obtained by a single reshape call).

A simple 2D reshaped version of this result (without any transposition) looks like this -

x.reshape(4, 4)

array([[ 0,  0,  0,  0],
       [ 5, 10, 15, 20],
       [ 6, 12, 18, 24],
       [ 7, 14, 21, 28]])

Now consider this output with respect to your expected output -

array([[ 0,  5,  0, 10],
       [ 6,  7, 12, 14],
       [ 0, 15,  0, 20],
       [18, 21, 24, 28]])

You'll notice that your expected result is obtained by a Z-like traversal of your incorrectly shaped output -

start
    | /|     /| /|
    |/ |    / |/ |
      /    /    / 
     /    /    /
    | /| /    | /|
    |/ |/     |/ |
                 end

This implies that you must move over the array in varying strides to get your expected result. In conclusion, a simple reshape is not enough. You must transpose the original array in such a manner that these Z-like elements become contiguous, so that a subsequent reshape call gives you the output you need.
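
To make the "varying strides" concrete, here is a small sketch that prints the order in which the original memory is visited, along with the byte strides before and after the transpose. The dtype is fixed to int64 only so that the stride values are predictable; `flat_pos` and `read_order` are names made up for this illustration.

```python
import numpy as np

# Same values as in the question; int64 is fixed so the byte strides are predictable.
x = np.array([[[[0, 0], [0, 0]], [[5, 10], [15, 20]]],
              [[[6, 12], [18, 24]], [[7, 14], [21, 28]]]], dtype=np.int64)

# Label every element with its position in x's flat (C order) layout,
# then read those labels in the order the transposed array is traversed.
flat_pos = np.arange(x.size).reshape(x.shape)
read_order = flat_pos.transpose(2, 0, 3, 1).ravel()
print(read_order)

xt = x.transpose(2, 0, 3, 1)
print(x.strides, xt.strides)        # the transpose only permutes the strides
is_view = np.shares_memory(x, xt)   # the transpose itself does not copy data
is_copy = not np.shares_memory(x, xt.reshape(4, 4))  # the final reshape has to copy
print(is_view, is_copy)
```

The jumps in `read_order` (0, 4, 1, 5, ...) are the zig-zag shown in the sketch above: no single reshape of the original memory can produce them, so the reshape after the transpose has to copy the data.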

To understand how to transpose correctly, you should trace the elements along the input and figure out what axes you need to jump to get to each one in the output. The transposition follows accordingly. Divakar's answer does a sterling job of explaining this.

cs95
  • ... or just flip it during the transpose – juanpa.arrivillaga Dec 26 '17 at 10:37
  • @juanpa.arrivillaga Why'd you delete? It looked right. – cs95 Dec 26 '17 at 10:38
  • Because it is definitely more elegant to just `.transpose(2, 0, 3, 1)` instead of `.transpose(0,2,1,3)` and *then* reshape, taking the array in fortran-order... – juanpa.arrivillaga Dec 26 '17 at 10:39
  • @cᴏʟᴅsᴘᴇᴇᴅ Thanks! So, there isn't a straightforward way to do it using only `reshape` without `np.prod` et al? – kmario23 Dec 26 '17 at 10:45
  • @kmario23 Unfortunately not. `reshape` only changes the view (number of dimensions). Judging by your input and expected output, you required a transposition, rather than a simple reshape. – cs95 Dec 26 '17 at 10:46
  • @kmario23 Hmm, I've added some explanation, I hope that's what you're looking for. – cs95 Dec 26 '17 at 11:00
  • @cᴏʟᴅsᴘᴇᴇᴅ Needed to use your solution to explain for a generic case. Hope that's okay. – Divakar Dec 26 '17 at 11:34
  • @Divakar No problems at all. Thanks for weighing in with your expertise :-) – cs95 Dec 26 '17 at 11:38
  • @cᴏʟᴅsᴘᴇᴇᴅ I'm accepting Divakar's answer since it's much more straightforward and simple. But, your explanations are good as well! – kmario23 Dec 26 '17 at 12:19
  • @kmario23 No problem. My explanation follows a different angle from Divakar's, because I wanted to correct your misconception that a simple reshape was enough. For this purpose, I went with analysing the incorrectly reshaped output rather than the original input. I have no complaints with the acceptance, his answer is gold standard. – cs95 Dec 26 '17 at 12:21

Divakar's answer is great, though sometimes it is easier for me to just check all the possible cases that transpose and reshape cover.

For example, the following code

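import itertools
import numpy as np

n, m = 4, 2
arr = np.arange(n*n*m*m).reshape(n, n, m, m)
# Try every axis order and print the first row of each reshaped result
for permut in itertools.permutations(range(4)):
    arr2 = arr.transpose(permut).reshape(n*m, n*m)
    print(permut, arr2[0])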

gives me everything one can get from a 4-dimensional array using transpose + reshape. Since I know what the output should look like, I just pick the permutation that shows me the correct answer. If I don't get what I want, then transpose + reshape is not general enough to cover my case and I have to do something more complicated.
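
When the expected output is known in full, the picking can also be automated instead of being done by eye. The helper below, `find_permutations`, is a hypothetical name introduced for this sketch; it is applied here to the arrays from the question.

```python
import itertools
import numpy as np

def find_permutations(arr, expected):
    """Return every axis order for which transpose + reshape reproduces `expected`."""
    return [perm for perm in itertools.permutations(range(arr.ndim))
            if np.array_equal(arr.transpose(perm).reshape(expected.shape), expected)]

a = np.array([[[[0, 0], [0, 0]], [[5, 10], [15, 20]]],
              [[[6, 12], [18, 24]], [[7, 14], [21, 28]]]])
out = np.array([[0, 5, 0, 10],
                [6, 7, 12, 14],
                [0, 15, 0, 20],
                [18, 21, 24, 28]])

found = find_permutations(a, out)
print(found)
```

For the arrays in the question, the only match is (2, 0, 3, 1). An empty list would mean that a single transpose + reshape cannot produce the expected output.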

cheyp