
Problem setup: I have a 3D spatial grid of data of size (n1, n2, n3) = (nx, ny, nz), where nz may be 1 for either a or b. Each grid point carries a vector a of length NDIM (typically 4) and a matrix b of size NDIM x NDIM. I want to compute per-point products such as a.b or b.a as efficiently as possible in both memory and CPU time.

Essentially, I'd like to generalize the question "A loopless 3D matrix multiplication in python". I have a result that seems to work, but I don't understand it. Searching Google and Stack Overflow didn't help. Please explain and further generalize! Thanks!

import numpy as np
# gives a.b per point:
nx=5
ny=8
nz=3
a = np.arange(nx*ny*nz*4).reshape(4, nx,ny,nz)
b = np.arange(nx*ny*1*4*4).reshape(4, 4, nx,ny,1)
ctrue=a*0.0
for ii in np.arange(0,nx):
    for jj in np.arange(0,ny):
        for kk in np.arange(0,nz):
            ctrue[:,ii,jj,kk] = np.tensordot(a[:,ii,jj,kk],b[:,:,ii,jj,0],axes=[0,1])

c2 = (a[:,None,None,None] * b[:,:,None,None,None]).sum(axis=1).reshape(4,nx,ny,nz)
np.sum(ctrue-c2)
# gives 0 as required


# gives b.a per point:
ctrue2=a*0.0
for ii in np.arange(0,nx):
    for jj in np.arange(0,ny):
        for kk in np.arange(0,nz):
            ctrue2[:,ii,jj,kk] = np.tensordot(a[:,ii,jj,kk],b[:,:,ii,jj,0],axes=[0,0])

btrans=np.transpose(b,(1,0,2,3,4))
c22 = (a[:,None,None,None] * btrans[:,:,None,None,None]).sum(axis=1).reshape(4,nx,ny,nz)
np.sum(ctrue2-c22)
# gives 0 as required

# Note that only the single line for c2 and c22 are required -- the rest of the code is for testing/comparison to see if that line works.

# Issues/Questions:
# 1) Please explain why those things work and further generalize!
# 2) After reading about None=np.newaxis, I thought something like this would work:
c22alt = (a[:,None,:,:,:] * btrans[:,:]).sum(axis=1).reshape(4,nx,ny,nz)
np.sum(ctrue2-c22alt)
# but it doesn't.
# 3) I don't see how to avoid assignment of a separate btrans.  An np.transpose on b[:,:,None,None,None] doesn't work.

Other related links: "Numpy: Multiplying a matrix with a 3d tensor -- Suggestion" and "How to use numpy with 'None' value in Python?"

1 Answer


To start with, your code is horrendously overcomplicated. The products a.b and b.a can be simplified to:

c2 = (a * b).sum(axis=1)
c22 = (a * b.swapaxes(0, 1)).sum(axis=1)

Note that instead of np.sum(ctrue - c2) you should use np.all(ctrue == c2); the former could give the wrong result if the two methods just happen to give a result with the same sum!
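To illustrate why the sum of differences is a weak check, here is a minimal sketch with two made-up arrays (expected and wrong are hypothetical names, not taken from the question). The arrays differ, yet their differences cancel:

```python
import numpy as np

# Two different arrays whose elementwise differences cancel out.
expected = np.array([1.0, 2.0, 3.0])
wrong = np.array([2.0, 1.0, 3.0])   # first two entries swapped

# The differences are [-1, 1, 0], so the sum hides the mismatch.
diff_sum = np.sum(expected - wrong)

# An elementwise comparison catches it.
all_equal = np.all(expected == wrong)

print(diff_sum, all_equal)
```

For results that involve floating-point rounding, a tolerance-based comparison such as np.allclose(expected, wrong) is probably the safer choice; exact equality is fine here because the question's data are small integers.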

Why does this work? Consider a single element:

a0 = a[:, 0, 0, 0]
b0 = b[:, :, 0, 0, 0]

Taking the tensor dot np.tensordot(a0, b0, axes=(0, 1)) is equivalent to (a0 * b0).sum(axis=1). This is because of broadcasting: the (4,) shape of a0 is broadcast to the (4, 4) shape of b0, so a0 is multiplied elementwise along each row of b0; summing over axis 1 then gives the tensor dot.

For the other dot product, np.tensordot(a0, b0, axes=(0, 0)) is equivalent to (a0 * b0.T).sum(axis=1) where b0.T is the same as b0.transpose() which is the same as b0.swapaxes(0, 1). By transposing b0, a0 is effectively broadcasting against the other axis of b0; we could get the same result by (a0[:, None] * b0).sum(axis=0).
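The single-point equivalences above can be checked directly. This is a small sketch that uses random integer data rather than the question's arrays (the rng seed and the variable names other than a0 and b0 are assumptions made for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
a0 = rng.integers(0, 10, size=4)        # vector at one grid point
b0 = rng.integers(0, 10, size=(4, 4))   # matrix at the same grid point

# Contract a0 with axis 1 of b0: result[j] = sum_i a0[i] * b0[j, i]
via_tensordot_1 = np.tensordot(a0, b0, axes=(0, 1))
via_broadcast_1 = (a0 * b0).sum(axis=1)

# Contract a0 with axis 0 of b0: result[j] = sum_i a0[i] * b0[i, j]
via_tensordot_0 = np.tensordot(a0, b0, axes=(0, 0))
via_transpose_0 = (a0 * b0.T).sum(axis=1)
via_newaxis_0 = (a0[:, None] * b0).sum(axis=0)

print(np.array_equal(via_tensordot_1, via_broadcast_1))
print(np.array_equal(via_tensordot_0, via_transpose_0))
print(np.array_equal(via_tensordot_0, via_newaxis_0))
```

In ordinary matrix notation the first contraction is b0 @ a0 and the second is a0 @ b0, which may be an easier way to keep the two cases apart.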

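The great thing about NumPy elementwise operations is that you can completely ignore higher axes as long as their shapes correspond or can broadcast, so what works for a0 and b0 (mostly) works for a and b as well. The main exception is .T: on an array with more than two dimensions it reverses all of the axes, which is why the full-array version uses b.swapaxes(0, 1) instead of b.T.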
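As a sanity check on the full arrays, the sketch below rebuilds the question's a and b and compares the broadcast expressions against an explicit loop. The names ndim, ref2 and ref22 are introduced here for illustration, and the loop uses @ in place of the question's np.tensordot calls:

```python
import numpy as np

nx, ny, nz, ndim = 5, 8, 3, 4
a = np.arange(ndim * nx * ny * nz).reshape(ndim, nx, ny, nz)
b = np.arange(ndim * ndim * nx * ny).reshape(ndim, ndim, nx, ny, 1)

# Broadcasting aligns shapes from the right:
#   a -> (   4, nx, ny, nz)
#   b -> (4, 4, nx, ny,  1)
# so a's vector axis pairs with axis 1 of b, and b's trailing 1 stretches to nz.
c2 = (a * b).sum(axis=1)
c22 = (a * b.swapaxes(0, 1)).sum(axis=1)

# Reference results computed one grid point at a time.
ref2 = np.empty_like(c2)
ref22 = np.empty_like(c22)
for i, j, k in np.ndindex(nx, ny, nz):
    ref2[:, i, j, k] = b[:, :, i, j, 0] @ a[:, i, j, k]    # sum_i b[j, i] * a[i]
    ref22[:, i, j, k] = a[:, i, j, k] @ b[:, :, i, j, 0]   # sum_i a[i] * b[i, j]

print(c2.shape, np.array_equal(c2, ref2), np.array_equal(c22, ref22))
```

One thing to keep in mind for the memory side of the question: a * b builds a temporary of shape (4, 4, nx, ny, nz) before the sum reduces it, so the broadcast form costs NDIM times the memory of the result.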

Finally, we can make this much clearer using Einstein summation:

c2 = np.einsum('i...,ji...->j...', a, b)
c22 = np.einsum('i...,ij...->j...', a, b)
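A short check that the einsum forms agree with the broadcast-and-sum forms, using the question's array shapes. The last part is a hypothetical extension (a per-point matrix-matrix product, not something the question computes) meant only to suggest how the subscript notation might generalize:

```python
import numpy as np

nx, ny, nz, ndim = 5, 8, 3, 4
a = np.arange(ndim * nx * ny * nz).reshape(ndim, nx, ny, nz)
b = np.arange(ndim * ndim * nx * ny).reshape(ndim, ndim, nx, ny, 1)

# Broadcast-and-sum versions.
c2_sum = (a * b).sum(axis=1)
c22_sum = (a * b.swapaxes(0, 1)).sum(axis=1)

# einsum versions; the ellipsis stands for the (broadcast) grid axes.
c2_ein = np.einsum('i...,ji...->j...', a, b)    # sum_i a[i] * b[j, i]
c22_ein = np.einsum('i...,ij...->j...', a, b)   # sum_i a[i] * b[i, j]

# Hypothetical extension: matrix-matrix product at every grid point.
bb = np.einsum('ik...,kj...->ij...', b, b)
bb_point = b[:, :, 2, 3, 0] @ b[:, :, 2, 3, 0]  # same product at one point

print(np.array_equal(c2_sum, c2_ein), np.array_equal(c22_sum, c22_ein))
print(np.array_equal(bb[:, :, 2, 3, 0], bb_point))
```

Swapping the order of the subscripts on b is all it takes to switch between the two products, so no separate transposed copy of b is needed.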
ecatmur