
I have two PyTorch tensors: one is rank 3 and the other is rank 4. Is there a way to multiply them so that the product has the rank and shape of the first tensor? For instance, in this cross-attention snippet:

q = torch.linspace(1, 192, steps=192)
q = q.reshape(2, 4, 3, 8)
k = torch.linspace(2, 193, steps=192)
k = k.reshape(2, 4, 3, 8)
v = torch.linspace(3, 194, steps=192)
v = v.reshape(2, 4, 24)

k = k.permute(0, 3, 2, 1)
attn = torch.einsum("nchw,nwhu->nchu", q, k)

# Below is what doesn't work. I would like hidden_states to be a tensor of shape (2, 4, 24)
hidden_states = torch.einsum("chw,whu->chu", attn, v)

Is there a permutation/transpose I could apply to q, k, v, or attn that would allow me to multiply into (2, 4, 24)? I have yet to find one.

I currently receive this error: "RuntimeError: einsum(): the number of subscripts in the equation (3) does not match the number of dimensions (4) for operand 0 and no ellipsis was given" so I'm wondering how to use the ellipsis in this case, if that could be a solution.

Any explanation as to why this is or isn't possible would also be an accepted answer!

MScottWaller
    Please provide a [mre]. IIUC, there are several ways to use `einsum`. An example with smaller tensors: `A = torch.rand(2,5,12); B = torch.rand(5,3,5,4); torch.einsum('ijk,jlk->ijk', A, B.permute(0,2,1,3).reshape(5,5,12))`, but no way to verify. – Michael Szczesny Oct 03 '22 at 04:38
  • Thanks @MichaelSzczesny ! I've written what I hope is a much better question – MScottWaller Oct 03 '22 at 12:08

1 Answer


It seems like your q and k are 4D tensors of shape batch-channel-height-width (2x4x3x8). However, when applying an attention mechanism, one disregards the spatial arrangement of the features and treats them only as a "bag of features". That is, instead of q and k of shape 2x4x3x8, you should have 2x4x24:

q = torch.linspace(1, 192, steps=192)
q = q.reshape(2, 4, 3 * 8)  # collapse the spatial dimensions into a single one
k = torch.linspace(2, 193, steps=192)
k = k.reshape(2, 4, 3 * 8)  # collapse the spatial dimensions into a single one
v = torch.linspace(3, 194, steps=192)
v = v.reshape(2, 4, 24)

attn = torch.einsum("bcn,bcN->bnN", q, k)
# it is customary to convert the raw attn into probabilities using softmax
attn = torch.softmax(attn, dim=-1)
hidden_states = torch.einsum("bnN,bcN->bcn", attn, v)
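Running the snippet above gives `hidden_states` the requested shape of `(2, 4, 24)`. As an aside on the ellipsis mentioned in the question: `...` in `einsum` stands in for any leading dimensions not named in the equation, which is how a 3-subscript pattern can be applied to rank-4 operands without the "number of subscripts" error. A minimal sketch with made-up tensors (not the question's actual data, whose dimensions don't line up for this):

```python
import torch

a = torch.rand(2, 4, 3, 8)  # rank-4 operand, leading batch dim of 2
b = torch.rand(2, 8, 3, 5)  # rank-4 operand with a matching batch dim
# "..." absorbs the leading batch dimension(s) shared by both operands,
# so only c, h, w, u need explicit subscripts
out = torch.einsum("...chw,...whu->...chu", a, b)
print(out.shape)  # torch.Size([2, 4, 3, 5])
```

Note that the ellipsis only sidesteps the subscript-count error; it doesn't help here because the remaining named dimensions of `attn` and `v` still don't match, which is why reshaping to a "bag of features" is the right fix.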

Shai
    Upvoting for the excellent info, and marking as correct. I have a bigger issue at play, but this does provide an answer to the previous question. – MScottWaller Oct 03 '22 at 20:01