I'm trying to take the Hilbert-Schmidt inner product of two matrices.
For two matrices A and B, this operation is the trace (the sum down the diagonal) of the matrix product of the Hermitian conjugate of A with B. Since only the diagonal entries of that product are needed, computing the full matrix product wastes work on the off-diagonal terms.
In effect one needs to compute:
$$\operatorname{Tr}(A^\dagger B) = \sum_{ij} A^{*}_{ij} B_{ij}$$
where i indexes rows and j indexes columns.
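To sanity-check the identity (purely illustrative, using small dense NumPy arrays rather than the sparse case I actually care about):

import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 4)) + 1j * rng.standard_normal((4, 4))
B = rng.standard_normal((4, 4)) + 1j * rng.standard_normal((4, 4))

## trace of the full product vs. the elementwise sum
assert np.allclose(np.trace(A.conj().T @ B), np.sum(A.conj() * B))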
What is the fastest way to do this for sparse matrices? I found the following similar question:
What is the best way to compute the trace of a matrix product in numpy?
Currently, I am doing:
import numpy as np

def hilbert_schmidt_inner_product(mat1, mat2):
    ## find nonzero ij indices of each matrix
    mat1_ij = set([tuple(x) for x in np.array(list(zip(*mat1.nonzero())))])
    mat2_ij = set([tuple(x) for x in np.array(list(zip(*mat2.nonzero())))])
    ## find the ij indices that are nonzero in both matrices
    common_ij = np.array(list(mat1_ij & mat2_ij))
    ## select the common i,j entries from both (now 1D arrays)
    mat1_survived = np.array(mat1[common_ij[:, 0], common_ij[:, 1]])[0]
    mat2_survived = np.array(mat2[common_ij[:, 0], common_ij[:, 1]])[0]
    ## multiply the common nonzero elements (a 1D dot product!)
    trace = np.dot(mat1_survived.conj(), mat2_survived)
    return trace
However this is slower than:
import numpy as np
sum((mat1.conj().T@mat2).diagonal())
which does the full matrix product before taking the trace, and thus wastes effort computing off-diagonal elements it never uses. Is there a better way of doing this?
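One direction I have considered (just a sketch, I have not verified it is actually faster) is the sparse elementwise product, which should only ever touch entries that are nonzero in both matrices:

def hs_inner_product_via_multiply(mat1, mat2):
    ## elementwise (Hadamard) product of the two sparse matrices: only entries
    ## nonzero in both survive, then sum them all up
    return mat1.conj().multiply(mat2).sum()

but I am not sure whether .multiply() avoids enough work compared to the full @ product.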
I am using the following to benchmark:
import numpy as np
from scipy.sparse import rand
Dimension = 2**12
A = rand(Dimension, Dimension, density=0.001, format='csr')
B = rand(Dimension, Dimension, density=0.001, format='csr')
Running a few tests, I find:
%timeit hilbert_schmidt_inner_product(A,B)
49.2 ms ± 3.13 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit sum((A.conj().T@B).diagonal())
1.48 ms ± 32 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit np.einsum('ij,ij->', A.conj().todense(), B.todense())
53.9 ms ± 2.74 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
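Note also that scipy.sparse.rand generates real-valued matrices, so the conjugation does no real work in this benchmark. To exercise the complex case, test matrices could be built along these lines (a sketch; A_c and B_c are just hypothetical names):

A_c = rand(Dimension, Dimension, density=0.001, format='csr') \
      + 1j * rand(Dimension, Dimension, density=0.001, format='csr')
B_c = rand(Dimension, Dimension, density=0.001, format='csr') \
      + 1j * rand(Dimension, Dimension, density=0.001, format='csr')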