I have an array containing the lower-triangular elements (including the diagonal) of a single symmetric matrix. I can convert it to the full matrix by following the methods from "How to convert triangle matrix to square in NumPy?". For a single matrix, the example looks like this:
import numpy as np

# Create the lower-triangular elements (row-major, including the diagonal) of a 6x6 matrix.
ld = np.arange(21)
# Create full 6x6 matrix
x = np.zeros((6,6))
# Stuff lower triangular values into it
x[np.tril_indices(6)] = ld
# Populate upper triangular elements
x = x + x.T
# Fix the diagonal (it got doubled); diag_idx holds the positions of the diagonal entries in ld
diag_idx = [0, 2, 5, 9, 14, 20]
np.fill_diagonal(x, ld[diag_idx])
print(x)
and we get the expected full matrix:
[[ 0.  1.  3.  6. 10. 15.]
 [ 1.  2.  4.  7. 11. 16.]
 [ 3.  4.  5.  8. 12. 17.]
 [ 6.  7.  8.  9. 13. 18.]
 [10. 11. 12. 13. 14. 19.]
 [15. 16. 17. 18. 19. 20.]]
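The hard-coded diag_idx list follows a pattern: in the row-major order used by np.tril_indices, row i contributes i + 1 values and its diagonal entry is the last of them, so the diagonal positions are the triangular numbers minus one. A minimal sketch that computes them for any size n, assuming ld is ordered like np.tril_indices(n) (the helper name tril_to_full is hypothetical, not a NumPy function):

```python
import numpy as np

def tril_to_full(ld, n):
    """Rebuild a symmetric n x n matrix from its row-major lower-triangular values.

    Hypothetical helper for illustration; assumes len(ld) == n * (n + 1) // 2
    and that ld is ordered like np.tril_indices(n).
    """
    ld = np.asarray(ld)
    x = np.zeros((n, n))
    x[np.tril_indices(n)] = ld
    # Mirror into the upper triangle; this doubles the diagonal.
    x = x + x.T
    # Diagonal of row i sits at flat index (i + 1)(i + 2)/2 - 1: 0, 2, 5, 9, ...
    diag_idx = np.cumsum(np.arange(1, n + 1)) - 1
    np.fill_diagonal(x, ld[diag_idx])
    return x

print(tril_to_full(np.arange(21), 6))
```

For n = 6 the computed diag_idx is [0, 2, 5, 9, 14, 20], the same list as in the snippet above.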
Now I want to extend this to an array of N sets of lower-triangular elements and get back an array of N full matrices. The former has shape (N, 21) and the latter (N, 6, 6). I tried expanding the single-matrix example to 2 matrices:
# Two sets of lower diagonal elements
ld = np.arange(2*21).reshape(2, 21)
# Two sets of full 6x6 matrices
x = np.zeros((ld.shape[0], 6, 6))
# Find the lower triangular indices of each row and stuff them with the
# values from the corresponding row in the lower diagonal array
x[:, np.tril_indices(6)] = ld[:]
# Populate upper triangular elements
x[:] = x[:] + x[:].T
# Fix diagonals (they got doubled)
diag_idx = [0, 2, 5, 9, 14, 20]
np.fill_diagonal(x[:], ld[:][diag_idx])
but the line x[:, np.tril_indices(6)] = ld[:] fails with a shape mismatch:
ValueError: shape mismatch: value array of shape (2,21) could not be broadcast to indexing result of shape (2,2,21,6)
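The reported indexing shape can be reproduced directly. np.tril_indices(6) returns a tuple of two index arrays (rows and columns), each of length 21. When that tuple is nested inside another index, as in x[:, np.tril_indices(6)], NumPy converts it to a single (2, 21) integer array applied to axis 1 only, so the result keeps axis 0 (length 2), takes (2, 21) from the index array, and keeps the trailing axis (length 6). Unpacking the tuple so that the row indices address axis 1 and the column indices address axis 2 gives the (2, 21) shape that matches ld. A small sketch:

```python
import numpy as np

x = np.zeros((2, 6, 6))
idx = np.tril_indices(6)          # tuple: (row indices, column indices)
print(len(idx), idx[0].shape)     # 2 (21,)

# Tuple nested inside the index: treated as one (2, 21) integer array on axis 1.
print(x[:, idx].shape)            # (2, 2, 21, 6)

# Tuple unpacked: rows index axis 1, columns index axis 2, paired element-wise.
rows, cols = idx
print(x[:, rows, cols].shape)     # (2, 21)
```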
I could write a normal Python loop over the N sets of lower-triangular values, but I was trying to do it all with vectorized NumPy operations. Any suggestions on where I've gone wrong with my indexing?
The expected values in x are:
[[[ 0.  1.  3.  6. 10. 15.]
  [ 1.  2.  4.  7. 11. 16.]
  [ 3.  4.  5.  8. 12. 17.]
  [ 6.  7.  8.  9. 13. 18.]
  [10. 11. 12. 13. 14. 19.]
  [15. 16. 17. 18. 19. 20.]]

 [[21. 22. 24. 27. 31. 36.]
  [22. 23. 25. 28. 32. 37.]
  [24. 25. 26. 29. 33. 38.]
  [27. 28. 29. 30. 34. 39.]
  [31. 32. 33. 34. 35. 40.]
  [36. 37. 38. 39. 40. 41.]]]
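One way to vectorize the whole conversion, sketched under the assumption that every row of ld uses the row-major np.tril_indices ordering shown above: unpack the index tuple so it addresses axes 1 and 2, then write the same values a second time with rows and columns swapped. This is a different technique from the add-the-transpose approach in the question; because each diagonal entry is simply written twice with the same value, nothing gets doubled and no diagonal fix is needed. The function name batch_tril_to_full is hypothetical.

```python
import numpy as np

def batch_tril_to_full(ld, n):
    """Turn an (N, n*(n+1)/2) array of lower-triangular values into (N, n, n) symmetric matrices.

    Hypothetical helper; assumes each row of ld follows np.tril_indices(n) order.
    """
    ld = np.asarray(ld)
    rows, cols = np.tril_indices(n)
    x = np.zeros((ld.shape[0], n, n), dtype=float)
    # Advanced indexing on axes 1 and 2 selects an (N, len(rows)) block, matching ld.
    x[:, rows, cols] = ld
    # Mirror into the upper triangle; diagonal entries are rewritten with the same value.
    x[:, cols, rows] = ld
    return x

ld = np.arange(2 * 21).reshape(2, 21)
x = batch_tril_to_full(ld, 6)
print(x)
```

If you would rather keep the add-and-fix-diagonal steps from the single-matrix version, note that a bare .T on a 3-D array reverses all three axes, so the batched mirror step is x + x.transpose(0, 2, 1); and np.fill_diagonal requires all dimensions to have equal length, so the diagonal can instead be written with x[:, d, d] = ld[:, diag_idx], where d = np.arange(6).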