Whenever you use einsum, it helps to think about what each dimension means. In this case we essentially multiply the two inputs. The signature passed to einsum states which dimensions are preserved and which ones are "summed away". I simplified the signature to single letters here:
res = einsum("b q m, n m h -> b q n h", x, y)
We can read from this that both x and y have three dimensions. Both also have a dimension called m, which doesn't appear in the output, so we can conclude that it gets "summed away". Each entry of the output is therefore given by the following formula. For simplicity I reuse the dimension names as indices, so for every b, q, n, h we get
$$\mathrm{res}[b,q,n,h] = \sum_{m} x[b,q,m] \cdot y[n,m,h]$$
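To make the formula concrete, here is a minimal sketch that spells out the sum with explicit loops and compares it against einsum. It assumes that einsum refers to torch.einsum; the sizes B, Q, M, N, H and the random inputs are arbitrary values chosen only for illustration.

```python
import torch

# Small, arbitrary sizes chosen only for illustration.
B, Q, M, N, H = 2, 3, 4, 5, 6
x = torch.randn(B, Q, M)
y = torch.randn(N, M, H)

# Literal translation of the formula: one output entry per (b, q, n, h).
res_loop = torch.zeros(B, Q, N, H)
for b in range(B):
    for q in range(Q):
        for n in range(N):
            for h in range(H):
                # Sum over the shared dimension m.
                for m in range(M):
                    res_loop[b, q, n, h] += x[b, q, m] * y[n, m, h]

# Assumption: einsum in the text is torch.einsum.
res_einsum = torch.einsum("b q m, n m h -> b q n h", x, y)
print(torch.allclose(res_loop, res_einsum, atol=1e-5))
```

The loops are far too slow for real tensors; they are only meant to show which index is summed and which ones survive.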
Doing this with any function other than einsum is usually more cumbersome. First we need to reorder and unsqueeze the dimensions so that they are compatible for multiplication, so we can do the following (with the shapes annotated above):
# x[:, :, :, None, None]: (b, q, m, 1, 1)
# y.permute([1,0,2]):     (m, n, h)
# product:                (b, q, m, n, h)
product = x[:, :, :, None, None] * y.permute([1,0,2])
Due to the broadcasting rules, the second term (the y term) implicitly gets the required leading dummy dimensions. Then we can "sum away" the dimension m:
res = product.sum(dim=2) # (b,q,n,h)
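The two steps above can be put together into a small runnable check. This is only a sketch: it assumes PyTorch, and the sizes and random inputs are made up for illustration.

```python
import torch

# Arbitrary illustrative sizes.
B, Q, M, N, H = 2, 3, 4, 5, 6
x = torch.randn(B, Q, M)
y = torch.randn(N, M, H)

# (B, Q, M, 1, 1) * (M, N, H) broadcasts to (B, Q, M, N, H).
product = x[:, :, :, None, None] * y.permute(1, 0, 2)
# Sum away the shared dimension m (axis 2), leaving (B, Q, N, H).
res_manual = product.sum(dim=2)

res_einsum = torch.einsum("b q m, n m h -> b q n h", x, y)
print(product.shape, res_manual.shape)
print(torch.allclose(res_manual, res_einsum, atol=1e-5))
```

Note that the broadcast version materialises the full five-dimensional intermediate product, which can need a lot of memory for large inputs.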
The whole operation can be interpreted as a matrix multiplication if you want, or simply as a scalar product over m, carried out for many "batch" dimensions at once.
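The matrix-multiplication reading can be sketched by flattening the batch dimensions on each side so that only the shared dimension m is contracted. Again this assumes PyTorch, and the names x_mat and y_mat as well as the sizes are hypothetical and chosen for illustration only.

```python
import torch

# Arbitrary illustrative sizes.
B, Q, M, N, H = 2, 3, 4, 5, 6
x = torch.randn(B, Q, M)
y = torch.randn(N, M, H)

# Flatten (b, q) into rows and (n, h) into columns; m is the contracted axis.
x_mat = x.reshape(B * Q, M)                     # (B*Q, M)
y_mat = y.permute(1, 0, 2).reshape(M, N * H)    # (M, N*H)

# An ordinary matrix product, then restore the batch dimensions.
res_matmul = (x_mat @ y_mat).reshape(B, Q, N, H)

res_einsum = torch.einsum("b q m, n m h -> b q n h", x, y)
print(torch.allclose(res_matmul, res_einsum, atol=1e-5))
```

Each entry of the result is the scalar product of one row of x_mat with one column of y_mat, which is the "many scalar products" view.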