
If I have numpy arrays A and B, then I can compute the trace of their matrix product with:

tr = numpy.trace(A.dot(B))

However, the matrix multiplication A.dot(B) unnecessarily computes all of the off-diagonal entries in the matrix product, when only the diagonal elements are used in the trace. Instead, I could do something like:

n = A.shape[0]  # number of rows of A (and of columns of B)
tr = 0.0
for i in range(n):
    tr += A[i, :].dot(B[:, i])

but this runs the loop in Python code and isn't as obvious as numpy.trace.

Is there a better way to compute the trace of a matrix product of numpy arrays? What is the fastest or most idiomatic way to do this?

amcnabb

3 Answers


You can improve on @Bill's solution (the element-wise product answer below) by reducing intermediate storage to the diagonal elements only:

import numpy as np
from numpy.core.umath_tests import inner1d  # gufunc computing row-wise inner products

m, n = 1000, 500

a = np.random.rand(m, n)
b = np.random.rand(n, m)

# They all should give the same result
print(np.trace(a.dot(b)))
print(np.sum(a*b.T))
print(np.sum(inner1d(a, b.T)))

%timeit np.trace(a.dot(b))
10 loops, best of 3: 34.7 ms per loop

%timeit np.sum(a*b.T)
100 loops, best of 3: 4.85 ms per loop

%timeit np.sum(inner1d(a, b.T))
1000 loops, best of 3: 1.83 ms per loop

Another option is to use np.einsum and have no explicit intermediate storage at all:

# Will print the same as the others:
print(np.einsum('ij,ji->', a, b))

On my system it runs slightly slower than using inner1d, but that may not hold for all systems; see this question:

%timeit np.einsum('ij,ji->', a, b)
100 loops, best of 3: 1.91 ms per loop
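
A caveat: numpy.core.umath_tests is an internal module, and newer NumPy releases warn on import or no longer ship it, so inner1d may not be importable. Assuming only standard np.einsum, a sketch of the same diagonal-only computation (reusing a and b from above) is:

# Row-wise inner products of a's rows with b's columns, i.e. the diagonal of a.dot(b)
diag = np.einsum('ij,ji->i', a, b)
print(diag.sum())  # should match np.trace(a.dot(b))
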
Jaime
  • On my machine, the `inner1d` and `einsum` approaches are basically indistinguishable in speed. – amcnabb Sep 17 '13 at 16:48
  • One would think that `np.einsum` would have a slight edge, because it doesn't have to store all the diagonal elements before adding them; it keeps a running sum. And it uses SIMD instructions, so you should get a factor of 2 or 4 improvement over `inner1d`. But they perform identically on my system as well, even for larger data. – Jaime Sep 17 '13 at 16:51
  • By the way, is `umath_tests` a stable public API? The name makes it sound private, and it seems to be less documented than the other parts of numpy. – amcnabb Sep 17 '13 at 16:52
  • @amcnabb That's interesting; what version of numpy are you using? Examining the source code, it looks like `inner1d` is a C++ definition that is written to make use of `SSE`. See [here](https://github.com/numpy/numpy/blob/3abd8699dc3c71e389356ca6d80a2cb9efa16151/numpy/core/src/umath/umath_tests.c.src). Could help answer the `einsum` question. – Daniel Sep 17 '13 at 16:54
  • @Ophion, I'm running "numpy-1.7.1-5.fc19.x86_64" on Fedora, for what it's worth. – amcnabb Sep 17 '13 at 16:57
  • 1
    @amcnabb In numpy 1.8 `inner1d` is going to be included in `numpy.linalg._umath_linalg`, not sure if it will stay in `numpy.core.umath_tests`. It may move around, but I think there is a clear intention to keep it and expose it moer and more. – Jaime Sep 17 '13 at 17:43
  • @Ophion Where do you see the use of `SSE` in the source code for `inner1d`? If there is not an `#include "emmintrin.h"` or something similar, you don't have access to the `SSE` functions. See the beginning of [einsum.c.src](https://github.com/numpy/numpy/blob/31a550189371ed21f8d38edae02f71f18a729741/numpy/core/src/multiarray/einsum.c.src) for comparison. – Jaime Sep 17 '13 at 18:00
  • @Jaime I know `icc` can automatically generate code using `SSE2` without calling `__m128 ...` operations explicitly for Intel CPUs. Looks like GCC has this feature [also](http://gcc.gnu.org/projects/tree-ssa/vectorization.html), although that page looks outdated. Is there something that stops this? – Daniel Sep 17 '13 at 18:39
  • Yes, that's a possibility, and it could explain differences on my system, which has been compiled with MSVC, which I don't think does such optimizations. – Jaime Sep 17 '13 at 19:33

From Wikipedia, you can calculate the trace using the Hadamard product (element-wise multiplication):

# Tr(A.B)
tr = (A*B.T).sum()

This takes much less computation than numpy.trace(A.dot(B)): the element-wise product does one multiplication per entry of A (O(n^2) work for n x n matrices), whereas the full matrix product A.dot(B) does O(n^3) work, almost all of it on off-diagonal entries that the trace then ignores.
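
For reference, the identity being used (it is stated on the Wikipedia page for the trace) is

Tr(A.B) = sum_i (A.B)_ii = sum_i sum_j A_ij * B_ji,

and the right-hand side is exactly the sum of all entries of the element-wise product A * B.T, so no off-diagonal entry of A.dot(B) is ever formed.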

Edit:

Ran some timers. This way is much faster than using numpy.trace.

In [36]: from timeit import timeit

In [37]: timeit("np.trace(A.dot(B))", setup="""import numpy as np;
A, B = np.random.rand(1000,1000), np.random.rand(1000,1000)""", number=100)
Out[37]: 8.6434469223022461

In [38]: timeit("(A*B.T).sum()", setup="""import numpy as np;
A, B = np.random.rand(1000,1000), np.random.rand(1000,1000)""", number=100)
Out[38]: 0.5516049861907959
wflynny
  • This seems to be faster than `numpy.trace` even if the matrices aren't square (e.g. if `A` is 1000x100 and `B` is 100x1000, or vice versa). – amcnabb Sep 17 '13 at 16:27
  • 4
    You might want to mention that A and B must be `ndarrays` exclusively so its not confusing. Also it should be noted that the timings are heavily influenced by what kind of BLAS your numpy is linked to. For additional speed consider using the expression `np.einsum('ij,ji->',A,B)`. – Daniel Sep 17 '13 at 16:30
  • 1
    @Ophion Had included that in my answer before reading this... We may have another case of the mysterious slowness of `np.einsum` on my system... – Jaime Sep 17 '13 at 16:45

Note that one slight variant is to take the dot product of the vectorized (column-stacked) matrices; NumPy's .flatten('F') gives you that column-major vectorization. On my computer it's slightly slower than summing the Hadamard product, so it's a worse solution than wflynny's, but I think it's interesting because it can be more intuitive in some situations. For example, I find the vectorized form easier to reason about when working with the matrix normal distribution.
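
The identity behind this variant, writing vec(X) for the column-stacking of X (which is what X.flatten('F') returns), is

Tr(A.B) = vec(A) . vec(B.T),

since both sides equal sum_{i,j} A_ij * B_ji.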

Speed comparison, on my system:

import numpy as np
import time

N = 1000

np.random.seed(123)
A = np.random.randn(N, N)
B = np.random.randn(N, N)

start = time.time()
for i in range(10):
    C = np.trace(A.dot(B))
print(time.time() - start, C)

start = time.time()
for i in range(10):
    C = A.flatten('F').dot(B.T.flatten('F'))
print(time.time() - start, C)

start = time.time()
for i in range(10):
    C = (A.T * B).sum()
print(time.time() - start, C)

start = time.time()
for i in range(10):
    C = (A * B.T).sum()
print(time.time() - start, C)

Result:

6.246593236923218 -629.370798672
0.06539678573608398 -629.370798672
0.057890892028808594 -629.370798672
0.05709719657897949 -629.370798672
Hugh Perkins