I have two NumPy arrays a and b of shape [5, 5, 5] and [5, 5], respectively. For both a and b, the first entry in the shape is the batch size. When I perform the matrix multiplication operation, I get an array of shape [5, 5, 5]. An MWE is as follows.

import numpy as np

a = np.ones((5, 5, 5))
b = np.random.randint(0, 10, (5, 5))
c = a @ b
# c.shape is (5, 5, 5)

Suppose I instead loop over the batch and compute a[i] @ b[i].T for each i. Each product is an array of shape [5] (b[i] is one-dimensional, so .T has no effect on it). If I then add a trailing axis to each result and concatenate them along axis 1, I get an array of shape [5, 5]. The code below describes this more precisely.

a = np.ones((5, 5, 5))
b = np.random.randint(0, 10, (5, 5))
c = []
for i in range(5):
    c.append(a[i] @ b[i].T)
c = np.concatenate([d[:, None] for d in c], axis=1).T
# c.shape evaluates to be (5, 5)

Can I get the above functionality without using a loop? For example, PyTorch provides a function called torch.bmm to compute this. Thanks.

learner

2 Answers

Add an extra dimension to b to make the matrix multiplications batch compatible and remove the redundant last dimension at the end by squeezing:

c = np.matmul(a, b[:, :, None]).squeeze(-1)

Or equivalently:

c = (a @ b[:, :, None]).squeeze(-1)

Both make the shapes of a and b compatible for batched matrix multiplication by reshaping b to 5 x 5 x 1 in your example.
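As a quick check that this does not depend on the batch size being 5, here is a minimal sketch comparing the batched product with an explicit loop. The sizes B, M, N and the names c_loop and c_batched are made up for illustration and are deliberately all different, so a shape mistake would surface immediately.

```python
import numpy as np

# Hypothetical sizes, chosen to be distinct so shape errors are visible.
B, M, N = 3, 4, 6
rng = np.random.default_rng(0)
a = rng.integers(0, 10, size=(B, M, N))
b = rng.integers(0, 10, size=(B, N))

# Reference: loop over the batch, one matrix-vector product per item.
c_loop = np.stack([a[i] @ b[i] for i in range(B)])

# Batched: view each b[i] as an (N, 1) column, multiply, drop the last axis.
c_batched = (a @ b[:, :, None]).squeeze(-1)

assert c_batched.shape == (B, M)
assert np.array_equal(c_loop, c_batched)
```

Note that the loop here uses b[i] directly, with no transpose, and stacks the results along a new leading axis.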

swag2198
  • thanks, although your answer works for the case when `batch_size=5`, I am afraid for other batch sizes it would throw an error about the dimensions being mismatched. – learner Jul 17 '21 at 18:43
  • Can you provide the more general setup? It should work irrespective of the batch sizes as long as `a` is of shape _B x M x N_ and `b` of _B x N_ that will make `c` of shape _B x M_. – swag2198 Jul 18 '21 at 01:59
  • I had a transpose operator present, removing which the solution works, thanks! – learner Jul 18 '21 at 05:58

You can work this out using numpy einsum.

c = np.einsum('BNi,Bi->BN', a, b)

PyTorch also provides an einsum function (torch.einsum) that takes the same kind of subscript string, so the approach carries over with little change. It handles other shapes just as easily.

Then you don't have to worry about transpose or squeeze operations. It also saves memory because no copies of the existing arrays are created internally.
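To make the subscript string concrete, here is a small sketch with hypothetical sizes B, M, N (not from the question) that checks the einsum result against a plain loop. The lowercase labels b, m, n are just a renaming of the B, N, i labels used above.

```python
import numpy as np

# Hypothetical sizes for illustration only.
B, M, N = 3, 4, 6
rng = np.random.default_rng(0)
a = rng.integers(0, 10, size=(B, M, N))
b = rng.integers(0, 10, size=(B, N))

# 'bmn,bn->bm': keep the batch index b and row index m, sum over n.
c_einsum = np.einsum('bmn,bn->bm', a, b)

# Reference result from an explicit loop over the batch.
c_loop = np.stack([a[i] @ b[i] for i in range(B)])

assert c_einsum.shape == (B, M)
assert np.array_equal(c_einsum, c_loop)
```

Any index that appears in the inputs but not after the arrow (here n) is summed over, which is what makes each batch item a matrix-vector product.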

MSS