I have a tensor a whose size is torch.Size([330, 330, 36]).
The original data structure is [330, 6, 330, 6].
The meaning is:
I have 330 atoms in the system, and each atom has 6 orbitals.
I want to know the interactions between all atoms and all orbitals.
I want to perform these operations:
(1) b = a.reshape(330,330,6,6).permute(0,2,1,3).reshape(1980,1980)
to convert the tensor to a (330 x 6) x (330 x 6) matrix
(2) torch.sum(torch.diag(b@b)[1:6])
where b is the matrix from step (1), to multiply b by itself and sum the diagonal elements at indices 1-5
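For reference, the two steps above can be put together as a small runnable sketch. The random data, the float64 dtype, and the names n_atoms and n_orb are assumptions for illustration only; the real tensor a would come from the actual system.

```python
import torch

n_atoms, n_orb = 330, 6  # 330 atoms, 6 orbitals each (names are illustrative)

torch.manual_seed(0)
# Placeholder data standing in for the real interaction tensor.
a = torch.rand(n_atoms, n_atoms, n_orb * n_orb, dtype=torch.float64)

# (1) Split the last axis into two orbital axes, move each orbital axis next to
#     its atom axis, then merge (atom, orbital) pairs into rows and columns.
b = (
    a.reshape(n_atoms, n_atoms, n_orb, n_orb)
    .permute(0, 2, 1, 3)
    .reshape(n_atoms * n_orb, n_atoms * n_orb)
)

# (2) Multiply b by itself and sum the diagonal entries at indices 1..5.
result = torch.sum(torch.diag(b @ b)[1:6])
```

With this layout, row i * 6 + p and column j * 6 + q of b hold a[i, j, p * 6 + q], so diagonal indices 1-5 all belong to atom 0 (orbitals 1-5).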
Is there any way to perform the matmul operation without reshaping the 330x330x36 tensor into a 1980x1980 matrix?
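One possible route, sketched here under assumptions, is torch.einsum on a 4-D view of a. It still calls reshape, but only to split the last axis (a view for a contiguous tensor), and it never builds the 1980x1980 matrix or the full product. The random data and the names a4, rows, cols, n_atoms, and n_orb are illustrative, not part of the original code.

```python
import torch

n_atoms, n_orb = 330, 6  # illustrative names

torch.manual_seed(0)
# Placeholder data standing in for the real interaction tensor.
a = torch.rand(n_atoms, n_atoms, n_orb * n_orb, dtype=torch.float64)

# a4[i, j, p, q]: atom i, atom j, orbital p of atom i, orbital q of atom j.
a4 = a.reshape(n_atoms, n_atoms, n_orb, n_orb)

# With b[i*6 + p, j*6 + q] = a4[i, j, p, q], the diagonal of b @ b is
#   diag[i*6 + p] = sum over j, q of a4[i, j, p, q] * a4[j, i, q, p].
# Diagonal indices 1..5 correspond to atom i = 0, orbitals p = 1..5.
rows = a4[0, :, 1:6, :]  # indexed (j, p, q)
cols = a4[:, 0, :, 1:6]  # indexed (j, q, p)
result = torch.einsum('jpq,jqp->', rows, cols)

# The whole diagonal of b @ b, if it is needed, without forming b @ b:
full_diag = torch.einsum('ijpq,jiqp->ip', a4, a4).reshape(-1)
```

Only the slices for atom 0 enter the partial sum, so this does far less work than the full 1980x1980 product.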
Thanks a lot.
What if I have a list of matrices? How can I do the matmul operations for all of them in a single command?
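A minimal sketch of one way to do this, assuming all matrices in the list have the same shape and that each one should be multiplied by itself: stack them into a single 3-D tensor and let the @ operator do a batched matmul. The list matrices, its 8x8 size, and the random data are hypothetical.

```python
import torch

torch.manual_seed(0)
# Hypothetical list of equally sized square matrices.
matrices = [torch.rand(8, 8, dtype=torch.float64) for _ in range(3)]

# Stack into shape (3, 8, 8); @ then multiplies matching batch entries.
stacked = torch.stack(matrices)
products = stacked @ stacked  # products[k] == matrices[k] @ matrices[k]

# Per-matrix sum of the diagonal entries at indices 1..5.
diagonals = torch.diagonal(products, dim1=-2, dim2=-1)  # shape (3, 8)
partial_sums = diagonals[:, 1:6].sum(dim=-1)            # shape (3,)
```

If the matrices have different shapes they cannot be stacked this way, and a loop or padding would be needed instead.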