Let A be an n×m matrix and M an m×m matrix. Writing tr() for the trace of a matrix, I need to compute tr(A M A^T). However, the trace only uses the diagonal, so forming the full n×n product A M A^T throws away most of the computation. Can I use NumPy's or PyTorch's broadcasting rules to compute only the diagonal of A M A^T?

Update: Here is my solution to compute the diagonal in PyTorch:

torch.sum(torch.sum(A.t()[:,None,:]*M[:,:,None],0)*A.t(),0)
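For comparison, here is a minimal NumPy sketch of the same diagonal written with `np.einsum`, which avoids the (m, m, n) intermediate that the broadcasting version creates. The function name `diag_of_AMAt` and the random test matrices are only illustrative assumptions, not part of the original solution.

```python
import numpy as np

def diag_of_AMAt(A, M):
    """Return diag(A @ M @ A.T) without forming the n x n product."""
    # result[i] = sum over j, k of A[i, j] * M[j, k] * A[i, k]
    return np.einsum("ij,jk,ik->i", A, M, A)

# Hypothetical example data: n = 5, m = 3
rng = np.random.default_rng(0)
A = rng.standard_normal((5, 3))
M = rng.standard_normal((3, 3))

d = diag_of_AMAt(A, M)   # shape (5,)
trace = d.sum()          # tr(A M A^T) is the sum of that diagonal
```

Summing `d` gives the trace; PyTorch's `torch.einsum` accepts the same subscript string, so the idea should carry over.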

ASML
  • Does this answer your question? [What is the best way to compute the trace of a matrix product in numpy?](https://stackoverflow.com/questions/18854425/what-is-the-best-way-to-compute-the-trace-of-a-matrix-product-in-numpy) – iacob Mar 23 '21 at 13:25

1 Answer

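You will have to compute at least one of the two matrix products (A M or M A^T). After that, you can apply one of the answers from [What is the best way to compute the trace of a matrix product in numpy?](https://stackoverflow.com/questions/18854425/what-is-the-best-way-to-compute-the-trace-of-a-matrix-product-in-numpy)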
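As a rough sketch of what that could look like in NumPy: compute A M once, then replace the second product and the trace with an elementwise multiply and a sum, since tr(A M A^T) = sum over i, j of (A M)[i, j] * A[i, j]. The helper name `trace_AMAt` and the test data are assumptions made for illustration.

```python
import numpy as np

def trace_AMAt(A, M):
    """Return tr(A @ M @ A.T) using a single matrix product."""
    AM = A @ M              # shape (n, m): the one product that is computed
    return np.sum(AM * A)   # elementwise product, then sum = trace

# Hypothetical example data: n = 4, m = 6
rng = np.random.default_rng(1)
A = rng.standard_normal((4, 6))
M = rng.standard_normal((6, 6))

t = trace_AMAt(A, M)
```

If the diagonal itself is needed rather than the trace, `(AM * A).sum(axis=1)` gives it from the same intermediate.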

Pim