For example, I have an array A of shape (3,2,2), e.g.

[
[[1,1],[1,1]], 
[[2,2],[2,2]], 
[[3,3],[3,3]]
]

and matrix B of shape (2,2), e.g.

[[1, 1], [0,1]]

I would like to compute c of shape (3,2,2), where each c[i] is the matrix product of B and A[i]:

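import numpy as np
A = np.array([[[1,1],[1,1]], [[2,2],[2,2]], [[3,3],[3,3]]])
B = np.array([[1, 1], [0, 1]])
c = np.zeros((3,2,2))
for i in range(len(A)):
    c[i] = np.dot(B, A[i,:,:])  # c[i] = B @ A[i]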

which gives

[[[2. 2.]
  [1. 1.]]

 [[4. 4.]
  [2. 2.]]

 [[6. 6.]
  [3. 3.]]]

What is the most efficient way to achieve this?

Thanks.

dxli
  • np.tensordot(A,B,axes=((2),(1))) gives [[[2 1] [2 1]] [[4 2], [4 2]] [[6 3] [6 3]]], which is different from what's expected. What should be the correct way in my case, please? – dxli Jun 11 '18 at 12:36
  • What would be typical shapes of A and B? The performance of various methods would depend on them. – Divakar Jun 11 '18 at 13:28
  • A is of shape (n, 2, 2), where n can be rather large. B is always of shape (2, 2). Which method would be preferred please? – dxli Jun 11 '18 at 13:31

2 Answers

Use np.tensordot and then swap axes. So, use one of these -

np.tensordot(B,A,axes=((1),(1))).swapaxes(0,1)
np.tensordot(A,B,axes=((1),(1))).swapaxes(1,2)
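
As a quick sanity check, the sketch below compares both expressions against the explicit loop from the question on random input. The helper name `loop_reference`, the random test shapes, and the broadcasting `np.matmul` call used as an extra cross-check are assumptions added for illustration; they are not part of the answer above.

```python
import numpy as np

def loop_reference(A, B):
    """Explicit loop from the question: c[i] = B @ A[i]."""
    c = np.zeros((A.shape[0], B.shape[0], A.shape[2]))
    for i in range(len(A)):
        c[i] = np.dot(B, A[i])
    return c

rng = np.random.default_rng(0)
A = rng.random((5, 2, 2))   # assumed test input of shape (n, 2, 2)
B = rng.random((2, 2))

expected = loop_reference(A, B)

# Contract axis 1 of B with axis 1 of A -> axes (j, i, k); move i to the front.
via_BA = np.tensordot(B, A, axes=((1), (1))).swapaxes(0, 1)
# Contract axis 1 of A with axis 1 of B -> axes (i, k, j); swap the last two.
via_AB = np.tensordot(A, B, axes=((1), (1))).swapaxes(1, 2)
# Independent cross-check: matmul broadcasts B over the leading axis of A.
via_matmul = np.matmul(B, A)

print(np.allclose(via_BA, expected),
      np.allclose(via_AB, expected),
      np.allclose(via_matmul, expected))
```

Contracting axis 2 of A instead, as in np.tensordot(A,B,axes=((2),(1))) from the comments, computes A[i] @ B.T for each i, which is why that call gives a different result.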

Alternatively, we can swap the axes of A so that the contracted axis comes first, reshape it to 2D, do a single 2D matrix multiplication with np.dot, and then reshape and swap axes back. This may give a marginal performance boost.
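
The reshape idea can be sketched roughly as follows. The function name `reshape_dot` is hypothetical, and the sketch only shows that the approach reproduces the expected result on the question's example; whether it is actually faster is not measured here.

```python
import numpy as np

def reshape_dot(A, B):
    """Hypothetical helper: compute B @ A[i] for all i with one 2D np.dot."""
    n, p, r = A.shape            # each A[i] is (p, r); B is (m, p)
    m = B.shape[0]
    # Bring the contracted axis to the front and flatten the rest: (p, n*r).
    A2d = A.swapaxes(0, 1).reshape(p, n * r)
    # Single 2D product: (m, p) @ (p, n*r) -> (m, n*r).
    out2d = np.dot(B, A2d)
    # Restore the (m, n, r) layout, then move n back to the front: (n, m, r).
    return out2d.reshape(m, n, r).swapaxes(0, 1)

# Example arrays from the question.
A = np.array([[[1, 1], [1, 1]], [[2, 2], [2, 2]], [[3, 3], [3, 3]]])
B = np.array([[1, 1], [0, 1]])

result = reshape_dot(A, B)
print(result)
```

Note that reshaping after swapaxes forces NumPy to copy the data, so any gain over tensordot depends on the array sizes.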

Timings -

# Original approach
def orgapp(A,B):
    m = A.shape[0]
    n = B.shape[0]
    r = A.shape[2]
    c = np.zeros((m,n,r))
    for i in range(len(A)):
        c[i] = np.dot(B, A[i,:,:])
    return c  

In [91]: n = 10000
    ...: A = np.random.rand(n,2,2)
    ...: B = np.random.rand(2,2)

In [92]: %timeit orgapp(A,B)
100 loops, best of 3: 12.2 ms per loop

In [93]: %timeit np.tensordot(B,A,axes=((1),(1))).swapaxes(0,1)
1000 loops, best of 3: 191 µs per loop

In [94]: %timeit np.tensordot(A,B,axes=((1),(1))).swapaxes(1,2)
1000 loops, best of 3: 208 µs per loop

# @Bitwise's solution
In [95]: %timeit np.flip(np.dot(A,B).transpose((0,2,1)),1)
1000 loops, best of 3: 697 µs per loop
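
To reproduce a comparison like this outside IPython, a minimal sketch using the standard-library timeit module could look like the following. The function name `tensordot_app` and the choice of 10 runs are assumptions; absolute numbers will differ by machine and NumPy version.

```python
import timeit
import numpy as np

def orgapp(A, B):
    # Original loop approach: c[i] = B @ A[i]
    c = np.zeros((A.shape[0], B.shape[0], A.shape[2]))
    for i in range(len(A)):
        c[i] = np.dot(B, A[i])
    return c

def tensordot_app(A, B):
    # First tensordot variant from this answer
    return np.tensordot(B, A, axes=((1), (1))).swapaxes(0, 1)

n = 10000
A = np.random.rand(n, 2, 2)
B = np.random.rand(2, 2)

runs = 10  # assumed repeat count; increase for steadier numbers
t_loop = timeit.timeit(lambda: orgapp(A, B), number=runs) / runs
t_tensordot = timeit.timeit(lambda: tensordot_app(A, B), number=runs) / runs

print(f"loop:      {t_loop * 1e3:.3f} ms per call")
print(f"tensordot: {t_tensordot * 1e3:.3f} ms per call")
```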
Divakar

Another solution:

np.flip(np.dot(A,B).transpose((0,2,1)),1)
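
This expression reproduces the expected output for the example in the question, where every A[i] is a constant matrix. It does not appear to be equivalent to B @ A[i] in general, though: it computes the row-flipped transpose of A[i] @ B. A minimal check follows, where the helper names and the non-constant input `A_general` are assumptions added for illustration.

```python
import numpy as np

B = np.array([[1, 1], [0, 1]])

def loop_reference(A, B):
    # Explicit per-slice product from the question: B @ A[i]
    return np.stack([np.dot(B, A[i]) for i in range(len(A))])

def flip_variant(A, B):
    # Expression from this answer
    return np.flip(np.dot(A, B).transpose((0, 2, 1)), 1)

# Example from the question: every A[i] is a constant matrix.
A_example = np.array([[[1, 1], [1, 1]], [[2, 2], [2, 2]], [[3, 3], [3, 3]]])
# Assumed non-constant input to test the general case.
A_general = np.array([[[1, 2], [3, 4]]])

matches_example = np.array_equal(flip_variant(A_example, B), loop_reference(A_example, B))
matches_general = np.array_equal(flip_variant(A_general, B), loop_reference(A_general, B))
print(matches_example, matches_general)  # True False
```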
Bitwise