
I would like to perform the following batch of matrix multiplications

```python
proj = torch.einsum('abi,aic->abc', A, B)
```

where A is an n×n×d tensor and B is an n×d×d tensor.
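For reference, with those shapes the einsum is just a batched matrix product over the first axis, so it is equivalent to `torch.bmm`. A minimal reproduction with small placeholder sizes (`n`, `d` here are illustrative, not the real problem size):

```python
import torch

n, d = 100, 8          # placeholder sizes; the real n is ~50k
A = torch.randn(n, n, d)
B = torch.randn(n, d, d)

# For each a: proj[a] = A[a] @ B[a], i.e. an (n×d) · (d×d) product
proj = torch.einsum('abi,aic->abc', A, B)
```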

When n gets large (~50k), this operation becomes very slow.

However, A is actually sparse in its first two dimensions, i.e., it can be represented as a set of index pairs (i, j) and a corresponding set of 1×d vectors.
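One possible sketch of exploiting that sparsity (the names `idx_i`, `idx_j`, `vals` are placeholders for the index pairs and 1×d vectors described above): only the m nonzero rows of A contribute to the output, so gather the matching B slice for each pair and contract just those, at cost O(m·d²) instead of O(n²·d²).

```python
import torch

n, d = 500, 8   # placeholder sizes
m = 2000        # number of nonzero (i, j) entries in A

# Hypothetical sparse representation of A: unique index pairs (i, j)
# plus the corresponding 1×d value vectors.
flat = torch.randperm(n * n)[:m]     # unique flattened positions
idx_i, idx_j = flat // n, flat % n
vals = torch.randn(m, d)             # A[idx_i[k], idx_j[k], :] = vals[k]

B = torch.randn(n, d, d)

# Gather B[idx_i[k]] for each nonzero pair and do a batched
# (d)·(d×d) contraction: proj_vals[k] = vals[k] @ B[idx_i[k]]
proj_vals = torch.einsum('kd,kde->ke', vals, B[idx_i])   # shape (m, d)

# Scatter back into a dense result only if the full tensor is needed.
proj = torch.zeros(n, n, d)
proj[idx_i, idx_j] = proj_vals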

Could someone help me speed this computation up?

Adam Gosztolai
