
I have a 3D matrix in Python, as follows:

import numpy as np

a = np.ones((2,2,3))
a[0,0,0] = 2
a[0,0,1] = 3
a[0,0,2] = 4

I want to convert this 3D matrix into a single 2D matrix. I have tried np.reshape, but it did not solve my problem. The final result I am interested in is the following cascaded version, where the columns of each 2x3 slice a[i] are placed one after another:

 [[ 2.  1.  3.  1.  4.  1.]
  [ 1.  1.  1.  1.  1.  1.]]

However, np.reshape gives me the following:

 [[ 2.  3.  4.  1.  1.  1.]
  [ 1.  1.  1.  1.  1.  1.]]

How can I solve this?
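To see why plain reshape produces the second result, here is a minimal sketch. reshape reads elements in C (row-major) order, with the last axis varying fastest, so each output row is a[i, 0, :] followed by a[i, 1, :]: the rows of each slice are concatenated rather than interleaved. The variable names below are illustrative only.

```python
import numpy as np

# Rebuild the array from the question (same as setting a[0,0,0..2] = 2, 3, 4).
a = np.ones((2, 2, 3))
a[0, 0, :] = [2, 3, 4]

# reshape walks the array in C order: last axis fastest.
flat = a.reshape(a.shape[0], -1)

# Equivalent manual construction: concatenate the rows of each 2x3 slice.
manual = np.array([np.concatenate([a[i, 0, :], a[i, 1, :]])
                   for i in range(a.shape[0])])

print(flat)
print(np.array_equal(flat, manual))  # True
```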

Divakar
A.M.

Use transpose along with reshape:

a.transpose([0,2,1]).reshape(a.shape[0],-1)
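A short sketch of why this works: transpose([0, 2, 1]) keeps axis 0 and swaps axes 1 and 2, so each 2x3 slice a[i] becomes its 3x2 transpose. Flattening that in C order then reads down the columns of the original slice, which gives the interleaved layout. The names t and out are just for illustration.

```python
import numpy as np

a = np.ones((2, 2, 3))
a[0, 0, :] = [2, 3, 4]

# Keep axis 0, swap axes 1 and 2: each slice a[i] becomes a[i].T.
t = a.transpose([0, 2, 1])
print(a.shape, "->", t.shape)  # (2, 2, 3) -> (2, 3, 2)

# C-order flattening of each transposed slice visits
# a[i,0,0], a[i,1,0], a[i,0,1], a[i,1,1], a[i,0,2], a[i,1,2].
out = t.reshape(a.shape[0], -1)
print(out[0])  # [2. 1. 3. 1. 4. 1.]
```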

Or use swapaxes, which does the same job as transpose here, along with reshape:

a.swapaxes(2,1).reshape(a.shape[0],-1)
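A quick check, as a sketch, that the two forms agree: for a 3D array, swapaxes(2, 1) is the same axis permutation as transpose([0, 2, 1]). Because the swapped view is not C-contiguous, the following reshape has to make a copy, so the result does not share memory with a.

```python
import numpy as np

a = np.ones((2, 2, 3))
a[0, 0, :] = [2, 3, 4]

# swapaxes exchanges exactly two axes; here it matches transpose([0, 2, 1]).
via_swap = a.swapaxes(2, 1).reshape(a.shape[0], -1)
via_transpose = a.transpose([0, 2, 1]).reshape(a.shape[0], -1)
print(np.array_equal(via_swap, via_transpose))  # True

# The swapped view cannot be flattened in place, so reshape copies.
print(np.shares_memory(via_swap, a))  # False
```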

Sample run:

In [66]: a
Out[66]: 
array([[[ 2.,  3.,  4.],
        [ 1.,  1.,  1.]],

       [[ 1.,  1.,  1.],
        [ 1.,  1.,  1.]]])

In [67]: a.transpose([0,2,1]).reshape(a.shape[0],-1)
Out[67]: 
array([[ 2.,  1.,  3.,  1.,  4.,  1.],
       [ 1.,  1.,  1.,  1.,  1.,  1.]])

In [68]: a.swapaxes(2,1).reshape(a.shape[0],-1)
Out[68]: 
array([[ 2.,  1.,  3.,  1.,  4.,  1.],
       [ 1.,  1.,  1.,  1.,  1.,  1.]])
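The same idea generalizes to any array of shape (n, rows, cols). Below is a sketch wrapped in a hypothetical helper, interleave_slices (not a NumPy function), cross-checked against a different technique: ravelling each slice in Fortran (column-major) order, which also reads down the columns.

```python
import numpy as np


def interleave_slices(arr):
    """Hypothetical helper: flatten each 2D slice arr[i] column by column.

    Takes shape (n, rows, cols) and returns shape (n, rows * cols)
    using the swapaxes-then-reshape approach.
    """
    if arr.ndim != 3:
        raise ValueError("expected a 3D array")
    return arr.swapaxes(1, 2).reshape(arr.shape[0], -1)


# Cross-check on random data against per-slice Fortran-order ravel.
rng = np.random.default_rng(0)
b = rng.integers(0, 10, size=(4, 3, 5))
expected = np.stack([m.ravel(order="F") for m in b])
print(np.array_equal(interleave_slices(b), expected))  # True
print(interleave_slices(b).shape)  # (4, 15)
```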
Divakar