This is the standard batch matrix multiplication:
import torch
a = torch.arange(12, dtype=torch.float).view(2,3,2)
b = torch.arange(12, dtype=torch.float).view(2,3,2) - 1
c = a.matmul(b.transpose(-1,-2))
a,b,c
>>
(tensor([[[ 0., 1.],
[ 2., 3.],
[ 4., 5.]],
[[ 6., 7.],
[ 8., 9.],
[10., 11.]]]),
tensor([[[-1., 0.],
[ 1., 2.],
[ 3., 4.]],
[[ 5., 6.],
[ 7., 8.],
[ 9., 10.]]]),
tensor([[[ 0., 2., 4.],
[ -2., 8., 18.],
[ -4., 14., 32.]],
[[ 72., 98., 124.],
[ 94., 128., 162.],
[116., 158., 200.]]]))
This is what I have, a single 2-D matrix multiplication over the flattened batch:
e = a.view(6,2)
f = b.view(6,2)
g = e.matmul(f.transpose(-1,-2))
e,f,g
>>
(tensor([[ 0., 1.],
[ 2., 3.],
[ 4., 5.],
[ 6., 7.],
[ 8., 9.],
[10., 11.]]),
tensor([[-1., 0.],
[ 1., 2.],
[ 3., 4.],
[ 5., 6.],
[ 7., 8.],
[ 9., 10.]]),
tensor([[ 0., 2., 4., 6., 8., 10.],
[ -2., 8., 18., 28., 38., 48.],
[ -4., 14., 32., 50., 68., 86.],
[ -6., 20., 46., 72., 98., 124.],
[ -8., 26., 60., 94., 128., 162.],
[-10., 32., 74., 116., 158., 200.]]))
It's obvious that g contains c as its diagonal blocks: g[0:3, 0:3] is c[0] and g[3:6, 3:6] is c[1]. I want to know if there is an efficient way to retrieve/slice c from g. Note that the retrieving/slicing method should generalize to any shape of a and b.
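For reference, here is a minimal sketch of one way the diagonal blocks could be pulled out of g. It assumes a has shape (B, n, k) and b has shape (B, m, k); the names B, n, m, idx and c_from_g are introduced only for this illustration. The idea is to view g as a (B, n, B, m) grid of blocks and keep the blocks whose two batch indices match.

```python
import torch

a = torch.arange(12, dtype=torch.float).view(2, 3, 2)
b = torch.arange(12, dtype=torch.float).view(2, 3, 2) - 1
c = a.matmul(b.transpose(-1, -2))              # batched result, shape (B, n, m)

# Assumed layout: a is (B, n, k), b is (B, m, k)
B, n, _ = a.shape
m = b.shape[1]

# Flattened product, shape (B*n, B*m)
g = a.reshape(B * n, -1).matmul(b.reshape(B * m, -1).transpose(-1, -2))

# View g as a B x B grid of (n, m) blocks and select the diagonal blocks.
# Indexing with the same index tensor on both batch axes picks block (i, i).
idx = torch.arange(B)
c_from_g = g.reshape(B, n, B, m)[idx, :, idx, :]   # shape (B, n, m)

print(torch.equal(c_from_g, c))
```

Note that g holds B times as many entries as c, so computing g only to slice it does more work than the batched matmul itself; the sketch only shows that the slicing generalizes beyond the 2x3x2 example.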