
Consider the following code:

import numpy as np
z = np.zeros((3,5,10,100))
indices = np.array([8, 0, 6, 1])
print(z[:,:,indices,:].shape)
print(z[1,:,indices,:].shape)

The output is as follows:

(3, 5, 4, 100)
(4, 5, 100)

I want to assign z[1,:,indices,:] = some_array, where some_array has shape (5,4,100), but the assignment raises an error because of a shape mismatch.
I am confused about the second output (the shape of z[1,:,indices,:]). I expected it to be (5,4,100). Why are the first two axes swapped? Is this a bug, or is there an explanation for why this is the correct behavior?
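The swap appears to follow NumPy's documented rule for combining advanced and basic indexing. When advanced indices are present, the scalar 1 is treated as an advanced index alongside the integer array indices. Because the two are separated by a slice, the broadcast index dimension is moved to the front of the result. With only one advanced index, or with adjacent advanced indices, that dimension stays in place. A minimal sketch comparing the cases (it uses np.arange instead of np.zeros only so the value comparison at the end is meaningful):

```python
import numpy as np

# Same shape as z in the question, but filled with distinct values.
z = np.arange(3 * 5 * 10 * 100).reshape(3, 5, 10, 100)
indices = np.array([8, 0, 6, 1])

# 1 and `indices` are both advanced indices, separated by a slice,
# so the broadcast index dimension (length 4) goes first.
separated = z[1, :, indices, :]
print(separated.shape)  # (4, 5, 100)

# z[1] alone is basic indexing (a view of shape (5, 10, 100)); applying
# `indices` to it leaves a single advanced index, so the axis stays in place.
one_advanced = z[1][:, indices, :]
print(one_advanced.shape)  # (5, 4, 100)

# A length-1 slice instead of the scalar also leaves a single advanced index.
sliced = z[1:2, :, indices, :]
print(sliced.shape)  # (1, 5, 4, 100)

# Adjacent advanced indices keep their position in the result.
adjacent = z[:, 1, indices, :]
print(adjacent.shape)  # (3, 4, 100)

# The two results hold the same data, only with the first two axes swapped.
print(np.array_equal(separated, one_advanced.transpose(1, 0, 2)))  # True
```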

I got my code to do the desired thing by changing it to:

for i, index in enumerate(indices):
    # some_array[:, i, :] has shape (5, 100), matching z[1, :, index, :]
    z[1, :, index, :] = some_array[:, i, :]
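The loop can probably be replaced by a single assignment. A minimal sketch, assuming some_array has shape (5, 4, 100) as stated above (the random values are only for illustration). The reference loop takes some_array[:, i, :] so its shape (5, 100) matches z[1, :, index, :]. Option 1 indexes through the view z[1]. Option 2 keeps the original mixed index and moves the length-4 axis of some_array to the front so it matches the (4, 5, 100) target:

```python
import numpy as np

z = np.zeros((3, 5, 10, 100))
indices = np.array([8, 0, 6, 1])
# Hypothetical data with the shape described in the question.
some_array = np.random.default_rng(0).random((5, 4, 100))

# Reference result: loop over the indices one slice of axis 2 at a time.
expected = z.copy()
for i, index in enumerate(indices):
    expected[1, :, index, :] = some_array[:, i, :]

# Option 1: z1[1] is a basic-indexing view, so assigning through it writes into z1.
z1 = z.copy()
z1[1][:, indices, :] = some_array

# Option 2: keep the mixed index and reorder some_array to shape (4, 5, 100).
z2 = z.copy()
z2[1, :, indices, :] = some_array.transpose(1, 0, 2)

print(np.array_equal(z1, expected), np.array_equal(z2, expected))  # True True
```

Option 1 works only because z[1] returns a view. Applying the advanced index first, as in z[:, :, indices][1] = ..., would write into a temporary copy and leave z unchanged.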