Let A be an n×m matrix and M an m×m matrix. Writing tr() for the trace of a matrix, I need to compute tr(A M A^T). However, the trace uses only the n diagonal entries of the n×n product A M A^T, so forming the full product throws away most of the computation. Can I use NumPy's or PyTorch's broadcasting rules to compute only the diagonal of A M A^T?
Update: Here is my solution for computing the diagonal in PyTorch (summing the result then gives the trace):
torch.sum(torch.sum(A.t()[:,None,:]*M[:,:,None],0)*A.t(),0)
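The one-liner above is correct, but it materializes an intermediate of shape (m, m, n) before reducing. A possibly lighter alternative, sketched here in NumPy, relies on the identity (A M A^T)[i, i] = sum_j (A M)[i, j] * A[i, j], so only the n×m product A M is formed. The function names `diag_of_AMAt` and `trace_of_AMAt` are made up for illustration, and I have not benchmarked this against the broadcasting version.

```python
import numpy as np

def diag_of_AMAt(A, M):
    """Return the diagonal of A @ M @ A.T without forming the n x n product."""
    # (A @ M) has shape (n, m); multiplying elementwise by A and summing
    # over the columns gives entry (i, i) of A @ M @ A.T for each row i.
    return np.sum((A @ M) * A, axis=1)

def trace_of_AMAt(A, M):
    """Return tr(A @ M @ A.T) as a scalar."""
    # einsum contracts all indices: sum over i, j, k of A[i,j] * M[j,k] * A[i,k].
    return np.einsum("ij,jk,ik->", A, M, A)

# Small check against the naive computation (shapes chosen arbitrarily).
rng = np.random.default_rng(0)
A = rng.standard_normal((5, 3))   # n = 5, m = 3
M = rng.standard_normal((3, 3))
full = A @ M @ A.T
assert np.allclose(diag_of_AMAt(A, M), np.diag(full))
assert np.allclose(trace_of_AMAt(A, M), np.trace(full))
```

The same expressions should carry over to PyTorch as `torch.sum((A @ M) * A, dim=1)` for the diagonal and `torch.einsum("ij,jk,ik->", A, M, A)` for the trace.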