
This is the standard batch matrix multiplication:

import torch
a = torch.arange(12, dtype=torch.float).view(2,3,2)
b = torch.arange(12, dtype=torch.float).view(2,3,2) - 1
c = a.matmul(b.transpose(-1,-2))
a,b,c

>> 
(tensor([[[ 0.,  1.],
          [ 2.,  3.],
          [ 4.,  5.]],
 
         [[ 6.,  7.],
          [ 8.,  9.],
          [10., 11.]]]),
 tensor([[[-1.,  0.],
          [ 1.,  2.],
          [ 3.,  4.]],
 
         [[ 5.,  6.],
          [ 7.,  8.],
          [ 9., 10.]]]),
 tensor([[[  0.,   2.,   4.],
          [ -2.,   8.,  18.],
          [ -4.,  14.,  32.]],
 
         [[ 72.,  98., 124.],
          [ 94., 128., 162.],
          [116., 158., 200.]]]))

This is what I have instead, a single matrix multiplication over the flattened tensors:

e = a.view(6,2)
f = b.view(6,2)
g = e.matmul(f.transpose(-1,-2))
e,f,g

>>
(tensor([[ 0.,  1.],
         [ 2.,  3.],
         [ 4.,  5.],
         [ 6.,  7.],
         [ 8.,  9.],
         [10., 11.]]),
 tensor([[-1.,  0.],
         [ 1.,  2.],
         [ 3.,  4.],
         [ 5.,  6.],
         [ 7.,  8.],
         [ 9., 10.]]),
 tensor([[  0.,   2.,   4.,   6.,   8.,  10.],
         [ -2.,   8.,  18.,  28.,  38.,  48.],
         [ -4.,  14.,  32.,  50.,  68.,  86.],
         [ -6.,  20.,  46.,  72.,  98., 124.],
         [ -8.,  26.,  60.,  94., 128., 162.],
         [-10.,  32.,  74., 116., 158., 200.]]))

Clearly g contains c: the 3x3 blocks on the diagonal of g are exactly the batches of c. Is there an efficient way to retrieve/slice c from g? The method should generalize to any shape of a and b.

namespace-Pt

1 Answer


Got it. We can slice g with advanced ("fancy") indexing: view g as a 4-D tensor of blocks, then keep only the blocks whose two batch indices match, i.e. the products computed within the same batch:

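# split both axes of g into (batch, row): shape (2, 3, 2, 3)
g = g.view(2,3,2,3)
# pair the two batch axes element-wise, keeping blocks (0,0) and (1,1): shape (2, 3, 3)
res = g[range(2),:,range(2),:]
res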
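
To generalize beyond the hard-coded 2 and 3, the same indexing can be wrapped in a small helper. This is only a sketch: the function name `diagonal_blocks` and its `batch` parameter are made up for illustration, and it assumes g was produced by flattening the batch and row dimensions of both inputs, as in the question.

```python
import torch

def diagonal_blocks(g: torch.Tensor, batch: int) -> torch.Tensor:
    """Extract the same-batch blocks from a flattened product.

    Assumes g has shape (batch * n, batch * m), i.e. g = e @ f.T where
    e is (batch, n, k) flattened to (batch * n, k) and f is (batch, m, k)
    flattened to (batch * m, k).
    """
    n = g.shape[0] // batch
    m = g.shape[1] // batch
    # Split each axis of g into (batch, row): shape (batch, n, batch, m).
    blocks = g.reshape(batch, n, batch, m)
    idx = torch.arange(batch, device=g.device)
    # Two index tensors separated by a slice are paired element-wise and
    # their dimension is placed first, giving shape (batch, n, m).
    return blocks[idx, :, idx, :]

# Reproduce the tensors from the question.
a = torch.arange(12, dtype=torch.float).view(2, 3, 2)
b = torch.arange(12, dtype=torch.float).view(2, 3, 2) - 1
c = a.matmul(b.transpose(-1, -2))
g = a.view(6, 2).matmul(b.view(6, 2).transpose(-1, -2))

res = diagonal_blocks(g, batch=2)
print(torch.equal(res, c))
```

Advanced indexing returns a copy. If a view is preferred, `torch.diagonal(blocks, dim1=0, dim2=2).movedim(-1, 0)` should select the same blocks without copying. Note also that g holds batch times as many entries as c, so this only pays off when g is already available for some other reason.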
namespace-Pt