I can make it simpler, coding-wise, with np.diag_indices. I don't know about its efficiency relative to a strided solution (if one is possible). Let's see if I can simplify my development history enough.
First, the indices:
In [2]: np.diag_indices(2)
Out[2]: (array([0, 1]), array([0, 1]))
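For a square 2d array, the two index arrays pair up as (0, 0) and (1, 1), so they select the main diagonal. A quick check, where the 2x2 array m is just an illustrative stand-in:

```python
import numpy as np

i, j = np.diag_indices(2)       # (array([0, 1]), array([0, 1]))
m = np.arange(4).reshape(2, 2)  # [[0, 1], [2, 3]]
diag = m[i, j]                  # picks m[0, 0] and m[1, 1]
print(diag)                     # [0 3]
print(np.diagonal(m))           # [0 3], the same elements
```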
To start simpler: we don't need 2 leading dimensions, since extra ones can be added later with reshape. We probably don't need the trailing dimension either, but I'll keep it for now:
In [3]: A1=np.arange(12).reshape(4,3)
Now build the reference solution:
In [4]: A2=np.zeros((4,2,2,3),int)
In [5]: A2[:,0,0,:]=A1
In [6]: A2[:,1,1,:]=A1
In [7]: A2
Out[7]:
array([[[[ 0,  1,  2],
         [ 0,  0,  0]],

        [[ 0,  0,  0],
         [ 0,  1,  2]]],


       [[[ 3,  4,  5],
         [ 0,  0,  0]],

        [[ 0,  0,  0],
         [ 3,  4,  5]]],


       [[[ 6,  7,  8],
         [ 0,  0,  0]],

        [[ 0,  0,  0],
         [ 6,  7,  8]]],


       [[[ 9, 10, 11],
         [ 0,  0,  0]],

        [[ 0,  0,  0],
         [ 9, 10, 11]]]])
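The two explicit assignments generalize to any diagonal size n with a short loop, which is handy as a reference to check other versions against. A sketch; the function name diag_loop is hypothetical, not part of NumPy:

```python
import numpy as np

def diag_loop(A1, n):
    """Reference version: copy A1 into each (k, k) diagonal slot."""
    out = np.zeros((A1.shape[0], n, n, A1.shape[1]), dtype=A1.dtype)
    for k in range(n):
        out[:, k, k, :] = A1   # for n=2 this is A2[:,0,0,:]=A1 then A2[:,1,1,:]=A1
    return out

A1 = np.arange(12).reshape(4, 3)
A2 = diag_loop(A1, 2)
print(A2.shape)   # (4, 2, 2, 3)
```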
Alternative:
In [8]: A3=np.zeros((4,2,2,3),int)
In [9]: i,j=np.diag_indices(2)
In [10]: A3[:,i,j,:]=A1
...
ValueError: shape mismatch: value array of shape (4,3) could not be broadcast to indexing result of shape (2,4,3)
The first attempt fails with a shape mismatch. Here is what that index expression selects in the reference A2:
In [12]: A2[:,i,j,:]
Out[12]:
array([[[ 0, 1, 2],
[ 0, 1, 2]],
[[ 3, 4, 5],
[ 3, 4, 5]],
[[ 6, 7, 8],
[ 6, 7, 8]],
[[ 9, 10, 11],
[ 9, 10, 11]]])
In [13]: A2[:,i,j,:].shape
Out[13]: (4, 2, 3)
We need to give A1 a size-1 middle axis so it can broadcast to the (4, 2, 3) destination slot.
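Broadcasting aligns shapes from the right, so (4, 3) is matched against the trailing (2, 3) of the slot and the 4 vs 2 pair clashes. np.broadcast_shapes (available in NumPy 1.20 and later) makes that check explicit; the variable names here are illustrative:

```python
import numpy as np

A1 = np.arange(12).reshape(4, 3)
slot = (4, 2, 3)   # shape of A3[:, i, j, :]

try:
    np.broadcast_shapes(A1.shape, slot)    # (4,3) vs (4,2,3): 4 vs 2 clashes
    plain_ok = True
except ValueError:
    plain_ok = False

fixed = np.broadcast_shapes(A1[:, None, :].shape, slot)  # (4,1,3) vs (4,2,3)
print(plain_ok, fixed)   # False (4, 2, 3)
```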
In [14]: A1.shape
Out[14]: (4, 3)
In [15]: A3[:,i,j,:] = A1[:,None,:]
In [16]: np.allclose(A2,A3)
Out[16]: True
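With an ellipsis standing in for the leading dimensions,
A2[...,i,j,:] = A1[...,None,:]
should handle your example, whatever the number of leading dimensions.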
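Wrapped as a function, the ellipsis form works for any number of leading dimensions and any diagonal size n. A minimal sketch; diag_expand is a hypothetical name, and it assumes the last axis of a is the one to keep intact:

```python
import numpy as np

def diag_expand(a, n):
    """Place a[..., :] on the (k, k) slots of a new (..., n, n, m) array."""
    out = np.zeros(a.shape[:-1] + (n, n) + a.shape[-1:], dtype=a.dtype)
    i, j = np.diag_indices(n)
    # out[..., i, j, :] has shape a.shape[:-1] + (n, m); the None axis broadcasts over n
    out[..., i, j, :] = a[..., None, :]
    return out

A = np.arange(24).reshape(2, 4, 3)   # two leading dimensions this time
B = diag_expand(A, 3)
print(B.shape)   # (2, 4, 3, 3, 3)
```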
An even simpler version starts with a 1d array and expands it to 3d:
In [21]: a1=np.arange(3)
In [22]: a3=np.zeros((2,2,3),int)
In [23]: a3[...,i,j,:]=a1[...,None,:]
In [24]: a3[i,j,:]=a1 # equivalent since a1[None,:] is automatic
In [25]: a3
Out[25]:
array([[[0, 1, 2],
[0, 0, 0]],
[[0, 0, 0],
[0, 1, 2]]])
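Another way to build the same a3, swapping index assignment for broadcasting against an identity matrix: np.eye(n) has ones exactly on the diagonal, so multiplying it (with a trailing axis) by a1 puts a1 on the diagonal slots and zeros elsewhere. A sketch for comparison only, not timed against the indexing version:

```python
import numpy as np

a1 = np.arange(3)
n = 2
eye = np.eye(n, dtype=a1.dtype)   # (n, n) identity
a3 = eye[:, :, None] * a1         # (n, n, 1) * (3,) -> (n, n, 3)
print(a3.tolist())   # [[[0, 1, 2], [0, 0, 0]], [[0, 0, 0], [0, 1, 2]]]
```

Note that this multiplies every one of the n*n*m slots, while the indexed assignment only writes the n diagonal rows.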
a3 doesn't have a repeated pattern of a1 values; or does it?
In [36]: a3.flatten()
Out[36]: array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 1, 2])
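In a sense it does: in the flat buffer the copies of a1 start at offsets 0 and 9, i.e. every (n+1)*m elements with n=2 and m=3. A plain slice with step n+1 over the (n*n, m) rows can exploit that and write straight into a3. A sketch assuming a3 is C-contiguous, so that reshape returns a view rather than a copy:

```python
import numpy as np

a1 = np.arange(3)
n, m = 2, a1.size
a3 = np.zeros((n, n, m), dtype=a1.dtype)

rows = a3.reshape(n * n, m)   # view: row k*n + l is a3[k, l]
rows[::n + 1] = a1            # rows 0, n+1, 2(n+1), ... are the (k, k) slots
print(a3.ravel())             # [0 1 2 0 0 0 0 0 0 0 1 2]
```

With a leading dimension the same idea reads A3.reshape(4, n*n, 3)[:, ::n+1] = A1[:, None, :].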
As you discovered, it's easy to fill all the slots with as_strided (imported here as ast), but hard to fill just the diagonals:
In [46]: ast(a1, shape=a3.shape, strides=(0,0,a1.strides[0]))
Out[46]:
array([[[0, 1, 2],
[0, 1, 2]],
[[0, 1, 2],
[0, 1, 2]]])
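The same all-slots view can be had without hand-written strides: np.broadcast_to returns a read-only view with zero strides on the new axes, and it takes the element size from the array itself instead of a hard-coded byte count. A sketch comparing the two, using the (2, 2, 3) shape of a3:

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

a1 = np.arange(3)
shape = (2, 2, 3)

# Explicit strides: 0 bytes along the first two axes, one element along the last
v1 = as_strided(a1, shape=shape, strides=(0, 0, a1.strides[0]))
# Same layout, computed by NumPy and marked read-only
v2 = np.broadcast_to(a1, shape)

print(v2.strides == v1.strides, v2.flags.writeable)   # True False
```

Writing through either view would hit the same three elements of a1 over and over, which is why neither helps with filling only the diagonals.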