In the seq2seq model, the encoder encodes the input sequences, which are given as mini-batches. Say, for example, the input is of shape B x S x d, where B is the batch size, S is the maximum sequence length, and d is the word embedding dimension. Then the encoder's output is of shape B x S x h, where h is the hidden state size of the encoder (which is an RNN).
Now, while decoding (during training), the target tokens are fed to the decoder one time step at a time, so its input is of shape B x 1 x d and it produces an output of shape B x 1 x h. To compute the context vector, we need to compare this decoder hidden state with the encoder's encoded states.
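For concreteness, here is a minimal sketch of where those shapes come from, assuming single-layer GRUs with batch_first=True and made-up sizes (everything here is illustrative, not part of the question):

import torch
import torch.nn as nn

B, S, d, h = 32, 10, 100, 256  # made-up sizes for illustration
encoder = nn.GRU(input_size=d, hidden_size=h, batch_first=True)
decoder = nn.GRU(input_size=d, hidden_size=h, batch_first=True)

src = torch.randn(B, S, d)                   # batch of embedded source sequences
enc_out, enc_hidden = encoder(src)           # enc_out: (B, S, h)
tgt_step = torch.randn(B, 1, d)              # one embedded target token per example
dec_out, _ = decoder(tgt_step, enc_hidden)   # dec_out: (B, 1, h)
print(enc_out.shape, dec_out.shape)          # torch.Size([32, 10, 256]) torch.Size([32, 1, 256])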
So you have two tensors: T1 of shape B x S x h and T2 of shape B x 1 x h. You can compare them with a batch matrix multiplication as follows.

out = torch.bmm(T1, T2.transpose(1, 2))
Essentially, you are multiplying a tensor of shape B x S x h with a tensor of shape B x h x 1, and it results in a tensor of shape B x S x 1: one attention score per encoder position, for each example in the batch.
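To see the shapes concretely, here is a self-contained version of that snippet (the sizes are made up for illustration):

import torch

B, S, h = 32, 10, 256              # made-up sizes
T1 = torch.randn(B, S, h)          # all encoder hidden states
T2 = torch.randn(B, 1, h)          # decoder hidden state at the current step
out = torch.bmm(T1, T2.transpose(1, 2))
print(out.shape)                   # torch.Size([32, 10, 1])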
Here, the scores of shape B x S x 1 represent the similarity between the decoder's current hidden state and each of the encoder's hidden states. In practice, they are usually normalized with a softmax over dim=1, so the resulting attention weights sum to 1 across the S encoder positions. Now you can multiply the attention weights with the encoder's hidden states of shape B x S x h: transpose the encoder states to B x h x S first, and the batch matrix multiplication results in a tensor of shape B x h x 1. If you then squeeze at dim=2, you get a tensor of shape B x h, which is your context vector.
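A minimal sketch of this step, reusing the made-up sizes from above (F.softmax over dim=1 is the usual normalization; the score tensor here is a random stand-in for the bmm output):

import torch
import torch.nn.functional as F

B, S, h = 32, 10, 256
T1 = torch.randn(B, S, h)          # encoder hidden states
scores = torch.randn(B, S, 1)      # stand-in for the bmm output above

attn_weights = F.softmax(scores, dim=1)                # (B, S, 1), sums to 1 over S
context = torch.bmm(T1.transpose(1, 2), attn_weights)  # (B, h, S) x (B, S, 1) -> (B, h, 1)
context = context.squeeze(2)                           # (B, h)
print(context.shape)               # torch.Size([32, 256])

The transposed bmm is just a batched weighted sum, so (attn_weights * T1).sum(dim=1) computes the same B x h result.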
This context vector (B x h) is usually concatenated with the decoder's hidden state (B x 1 x h, squeezed at dim=1) to predict the next token.
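For completeness, here is a sketch of that last step, assuming a simple linear output layer over a hypothetical vocabulary size (the projection layer and vocab_size are illustrative, not fixed by anything above):

import torch
import torch.nn as nn

B, h, vocab_size = 32, 256, 5000   # vocab_size is hypothetical
context = torch.randn(B, h)        # context vector from the attention step
dec_out = torch.randn(B, 1, h)     # decoder output at the current step

combined = torch.cat([dec_out.squeeze(1), context], dim=1)  # (B, 2h)
proj = nn.Linear(2 * h, vocab_size)                         # hypothetical output layer
logits = proj(combined)                                     # (B, vocab_size)
print(logits.shape)                # torch.Size([32, 5000])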