Consider the following code:
```python
import numpy as np
z = np.zeros((3,5,10,100))
indices = np.array([8, 0, 6, 1])
print(z[:,:,indices,:].shape)
print(z[1,:,indices,:].shape)
```
The output is as follows:
```
(3, 5, 4, 100)
(4, 5, 100)
```
I want to assign `z[1,:,indices,:] = some_array`, where `some_array` has shape `(5,4,100)`, but that assignment throws an error because of a shape mismatch.
I am confused about the second output (the shape of `z[1,:,indices,:]`). I thought it should be `(5,4,100)`. Why are the first two axes getting switched? Is this a bug, or is there an explanation for why this is the correct behavior?
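A minimal check of the behavior, assuming the rule for combining advanced and basic indexing from the NumPy indexing documentation: an integer used alongside an index array is treated as an advanced index too, and when the advanced indices are separated by a slice, the broadcast index dimension is placed first in the result. When the advanced indices are adjacent, or when `z[1]` is taken first with plain basic indexing, the index dimension stays where it is. The variable names `mixed`, `split` and `adjacent` are only for illustration.

```python
import numpy as np

z = np.zeros((3, 5, 10, 100))
indices = np.array([8, 0, 6, 1])

# `1` and `indices` are both advanced indices, separated by a slice,
# so the broadcast index dimension (4,) is moved to the front.
mixed = z[1, :, indices, :].shape

# z[1] is basic indexing (a view); only one advanced index remains,
# so its dimension stays at the position of `indices`.
split = z[1][:, indices, :].shape

# Advanced indices next to each other: the dimension stays in place.
adjacent = z[:, 1, indices, :].shape

print(mixed)     # (4, 5, 100)
print(split)     # (5, 4, 100)
print(adjacent)  # (3, 4, 100)
```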
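I got my code to do the desired thing by changing it to:
```python
# some_array has shape (5, 4, 100): take its i-th slice along axis 1,
# which has shape (5, 100) and matches z[1, :, index, :]
for i, index in enumerate(indices):
    z[1, :, index, :] = some_array[:, i, :]
```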
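A sketch of two loop-free assignments that should match the per-index loop, assuming `some_array` has shape `(5, 4, 100)` as stated (the random data below is hypothetical, and the loop indexes `some_array` along its second axis so that each slice has shape `(5, 100)`). The first indexes `z[1]` before applying the index array, so the result keeps the `(5, 4, 100)` layout; the second keeps the original expression and reorders the value's axes to the `(4, 5, 100)` layout it expects.

```python
import numpy as np

shape = (3, 5, 10, 100)
indices = np.array([8, 0, 6, 1])
# Hypothetical data with the shape from the question: (5, 4, 100).
some_array = np.random.default_rng(0).random((5, 4, 100))

# Reference: the per-index loop.
z_loop = np.zeros(shape)
for i, index in enumerate(indices):
    z_loop[1, :, index, :] = some_array[:, i, :]

# Option 1: z_view[1] is a view, so assigning through it writes into
# z_view; the indexed axis stays at position 1, giving shape (5, 4, 100).
z_view = np.zeros(shape)
z_view[1][:, indices, :] = some_array

# Option 2: keep z[1, :, indices, :] (shape (4, 5, 100)) and swap the
# first two axes of the value to match.
z_swap = np.zeros(shape)
z_swap[1, :, indices, :] = some_array.transpose(1, 0, 2)

same_view = np.array_equal(z_loop, z_view)
same_swap = np.array_equal(z_loop, z_swap)
print(same_view, same_swap)  # True True
```

Option 1 is usually the simpler choice because the value can be used as-is; `np.moveaxis(some_array, 1, 0)` is an equivalent way to write the transpose in option 2.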