I am trying to implement something similar to the positional encoding in the Transformer paper ("Attention Is All You Need"). To do that, I need the following: given three matrices, interleave their rows (the first row of each matrix grouped together, then the second rows, and so on), multiply each resulting three-row group by its transpose, and finally flatten each product and stack the results. I'll clarify this with the following example:
import torch

x = torch.tensor([[1, 1, 1, 1],
                  [2, 2, 2, 2],
                  [3, 3, 3, 3]])
y = torch.tensor([[0, 0, 0, 0],
                  [0, 0, 0, 0],
                  [0, 0, 0, 0]])
z = torch.tensor([[4, 4, 4, 4],
                  [5, 5, 5, 5],
                  [6, 6, 6, 6]])
# Concatenate along columns, then reshape so the rows of x, y, z interleave.
concat = torch.cat([x, y, z], dim=-1).view(-1, x.shape[-1])
print(concat)
tensor([[1, 1, 1, 1],
        [0, 0, 0, 0],
        [4, 4, 4, 4],
        [2, 2, 2, 2],
        [0, 0, 0, 0],
        [5, 5, 5, 5],
        [3, 3, 3, 3],
        [0, 0, 0, 0],
        [6, 6, 6, 6]])
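As an aside, I believe the same interleaving can also be obtained with torch.stack, which makes the row grouping explicit. A minimal equivalent sketch, assuming x, y, and z all share the same shape (n, d):

groups = torch.stack([x, y, z], dim=1)      # (n, 3, d): row i of each matrix grouped together
interleaved = groups.view(-1, x.shape[-1])  # (3 * n, d): same rows as concat above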
# Now take each group of three rows, multiply it by its transpose,
# flatten the result, and stack the flattened results.
concat = torch.stack([
    torch.flatten(
        torch.matmul(
            concat[i:i+3, :],  # 3 is the number of tensors (x, y, z)
            torch.transpose(concat[i:i+3, :], 0, 1))
    )
    for i in range(0, concat.shape[0], 3)
])
print(concat)
tensor([[  4,   0,  16,   0,   0,   0,  16,   0,  64],
        [ 16,   0,  40,   0,   0,   0,  40,   0, 100],
        [ 36,   0,  72,   0,   0,   0,  72,   0, 144]])
This gives exactly the final matrix I want. My question is: is there a way to achieve the last step without the Python loop, so that everything stays in tensor operations?
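One direction that seems plausible to me is to batch the three-row groups and use torch.bmm. A minimal sketch, reusing x, y, and z from above (the names groups, grams, and result are just placeholders of mine):

# Group row i of x, y, z into one (3, 4) matrix per i: shape (3, 3, 4).
groups = torch.stack([x, y, z], dim=1)
# Batched product of every group with its own transpose: shape (3, 3, 3).
grams = torch.bmm(groups, groups.transpose(1, 2))
# Flatten each 3x3 product into one row: shape (3, 9).
result = grams.flatten(start_dim=1)

If my understanding is right, result should equal the looped output above, but I am not sure whether this is the idiomatic approach or whether there is a better-known pattern for it.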