I'll try to illustrate the concern that @Divaker brings up.
In [522]: arr = np.arange(2*2*3*4).reshape(2,2,3,4)
In [523]: arr
Out[523]:
array([[[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]],
[[[24, 25, 26, 27],
[28, 29, 30, 31],
[32, 33, 34, 35]],
[[36, 37, 38, 39],
[40, 41, 42, 43],
[44, 45, 46, 47]]]])
4 is the innermost dimension, so the array displays as 3x4 blocks. If you pay attention to the spacing and the [], you'll see that those blocks are arranged 2x2.
Notice what happens when we use reshape:
In [524]: arr1 = arr.reshape(2,2,-1)
In [525]: arr1
Out[525]:
array([[[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]],
[[24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
[36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]]])
Now it is two 2x12 blocks. You can do anything to those 12-element rows, and then reshape them back to 3x4 blocks:
In [526]: arr1.reshape(2,2,3,4)
Out[526]:
array([[[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
...
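As a quick check of the round trip above, here is a minimal sketch. It assumes the same arr as in the session; the name restored is mine and not part of the original session.

```python
import numpy as np

arr = np.arange(2 * 2 * 3 * 4).reshape(2, 2, 3, 4)

# Merge the last two axes; -1 lets NumPy infer the length (3 * 4 = 12).
arr1 = arr.reshape(2, 2, -1)

# Splitting the 12-element rows back into 3x4 blocks recovers the original.
restored = arr1.reshape(2, 2, 3, 4)

print(arr1.shape)                     # (2, 2, 12)
print(np.array_equal(restored, arr))  # True

# For a contiguous array like this one, reshape returns a view, not a copy.
print(np.shares_memory(arr, arr1))    # True
```

No elements move here; only the shape and strides change, which is why the reshape is reversible without any transpose.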
But I could also split this array on the last dimension. np.split can do it, but a list comprehension is easier to understand:
In [527]: alist = [arr[...,i] for i in range(4)]
In [528]: alist
Out[528]:
[array([[[ 0, 4, 8],
[12, 16, 20]],
[[24, 28, 32],
[36, 40, 44]]]),
array([[[ 1, 5, 9],
[13, 17, 21]],
[[25, 29, 33],
[37, 41, 45]]]),
array([[[ 2, 6, 10],
[14, 18, 22]],
[[26, 30, 34],
[38, 42, 46]]]),
array([[[ 3, 7, 11],
[15, 19, 23]],
[[27, 31, 35],
[39, 43, 47]]])]
This list contains 4 (2,2,3) arrays. Note that the 3-element rows here display as columns in the 4d display.
I can reform these into a 4d array with np.stack (which is like np.array, but gives more control over how the arrays are joined):
In [529]: np.stack(alist, axis=-1)
Out[529]:
array([[[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
...
[[36, 37, 38, 39],
[40, 41, 42, 43],
[44, 45, 46, 47]]]])
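To make the role of the axis argument concrete, here is a small sketch that rebuilds alist and joins it two ways. The names stacked_last and stacked_first are mine, chosen only for illustration.

```python
import numpy as np

arr = np.arange(2 * 2 * 3 * 4).reshape(2, 2, 3, 4)

# Index the last axis one position at a time: four (2, 2, 3) arrays.
alist = [arr[..., i] for i in range(arr.shape[-1])]

# Joining on a new last axis reverses the split.
stacked_last = np.stack(alist, axis=-1)

# Joining on a new first axis (which is what np.array(alist) does) gives
# (4, 2, 2, 3), and needs a moveaxis to match the original layout.
stacked_first = np.stack(alist, axis=0)

print(stacked_last.shape)   # (2, 2, 3, 4)
print(stacked_first.shape)  # (4, 2, 2, 3)
print(np.array_equal(stacked_last, arr))                       # True
print(np.array_equal(np.moveaxis(stacked_first, 0, -1), arr))  # True
```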
==========
The split equivalent is [x[...,0] for x in np.split(arr, 4, axis=-1)]. Without the indexing, split produces (2, 2, 3, 1) arrays.
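A short sketch comparing the two approaches, under the same arr as above. The names pieces, from_split and from_index are mine.

```python
import numpy as np

arr = np.arange(2 * 2 * 3 * 4).reshape(2, 2, 3, 4)

# np.split keeps the split axis, so each piece has a trailing length-1 axis.
pieces = np.split(arr, 4, axis=-1)
print(pieces[0].shape)  # (2, 2, 3, 1)

# Indexing with [..., 0] drops that axis, matching the list comprehension.
from_split = [x[..., 0] for x in pieces]
from_index = [arr[..., i] for i in range(4)]
print(all(np.array_equal(a, b) for a, b in zip(from_split, from_index)))  # True

# Because the length-1 axis is still present, np.concatenate can rejoin the
# unindexed pieces directly; np.stack is only needed once the axis is dropped.
print(np.array_equal(np.concatenate(pieces, axis=-1), arr))  # True
```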
collapse_dims produces (for my example):
In [532]: np.rollaxis(arr,-1,2).reshape(arr.shape[0],arr.shape[1],-1)
Out[532]:
array([[[ 0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11],
[12, 16, 20, 13, 17, 21, 14, 18, 22, 15, 19, 23]],
[[24, 28, 32, 25, 29, 33, 26, 30, 34, 27, 31, 35],
[36, 40, 44, 37, 41, 45, 38, 42, 46, 39, 43, 47]]])
This is a (2,2,12) array, but with the elements of each row in a different order. It does a transpose on the inner 2 dimensions before flattening:
In [535]: arr[0,0,:,:].ravel()
Out[535]: array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
In [536]: arr[0,0,:,:].T.ravel()
Out[536]: array([ 0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11])
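The following sketch spells out that equivalence for this example. It assumes the same arr; the names rolled, moved, swapped and arr2 are mine, with arr2 standing for the collapsed (2,2,12) result.

```python
import numpy as np

arr = np.arange(2 * 2 * 3 * 4).reshape(2, 2, 3, 4)

# Three ways to put the last axis in front of the third one: (2, 2, 4, 3).
rolled = np.rollaxis(arr, -1, 2)
moved = np.moveaxis(arr, -1, 2)
swapped = arr.transpose(0, 1, 3, 2)
print(rolled.shape)  # (2, 2, 4, 3)
print(np.array_equal(rolled, moved) and np.array_equal(rolled, swapped))  # True

# Flattening the swapped axes gives the reordered 12-element rows.
arr2 = rolled.reshape(arr.shape[0], arr.shape[1], -1)
print(arr2[0, 0].tolist())  # [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11]

# Each row equals the transposed 3x4 block, raveled.
print(np.array_equal(arr2[1, 0], arr[1, 0].T.ravel()))  # True

# Unlike the plain reshape, this one has to copy the data.
print(np.shares_memory(arr, arr2))  # False
```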
Restoring the original order requires another roll or transpose (arr2 here is the (2,2,12) result from In [532]):
In [542]: arr2.reshape(2,2,4,3).transpose(0,1,3,2)
Out[542]:
array([[[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
....
[[36, 37, 38, 39],
[40, 41, 42, 43],
[44, 45, 46, 47]]]])
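To show why the extra transpose is needed, here is a sketch of the round trip. The helpers collapse_last and restore_last are hypothetical names of mine, not functions from the question or from NumPy, and they assume a 4d input.

```python
import numpy as np

def collapse_last(a):
    """Hypothetical helper: move the last axis before the third, then flatten."""
    return np.moveaxis(a, -1, 2).reshape(a.shape[0], a.shape[1], -1)

def restore_last(flat, n_rows, n_cols):
    """Hypothetical inverse: unflatten as (n_cols, n_rows), then swap back."""
    unflat = flat.reshape(flat.shape[0], flat.shape[1], n_cols, n_rows)
    return unflat.transpose(0, 1, 3, 2)

arr = np.arange(2 * 2 * 3 * 4).reshape(2, 2, 3, 4)
arr2 = collapse_last(arr)

restored = restore_last(arr2, 3, 4)
print(np.array_equal(restored, arr))  # True

# A plain reshape gives the right shape but the wrong element order.
wrong = arr2.reshape(2, 2, 3, 4)
print(wrong.shape)                 # (2, 2, 3, 4)
print(np.array_equal(wrong, arr))  # False
```

The point of the comparison is that the shape alone does not tell you whether the layout was restored; the reshape has to use the swapped (4,3) order first.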