
I'm following the PyTorch seq2seq tutorial, and the torch.bmm method is used like below:

attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                         encoder_outputs.unsqueeze(0))

I understand why we need to multiply attention weight and encoder outputs.

What I don't quite understand is why we need the bmm method here. The torch.bmm documentation says:

Performs a batch matrix-matrix product of matrices stored in batch1 and batch2.

batch1 and batch2 must be 3-D tensors each containing the same number of matrices.

If batch1 is a (b×n×m) tensor, batch2 is a (b×m×p) tensor, out will be a (b×n×p) tensor.

[figure from the tutorial: the attention computation in the decoder]
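For reference, a minimal runnable sketch of that shape rule (the sizes below are made up purely for illustration):

import torch

b, n, m, p = 4, 2, 3, 5          # arbitrary example sizes
batch1 = torch.randn(b, n, m)    # (b x n x m)
batch2 = torch.randn(b, m, p)    # (b x m x p)

out = torch.bmm(batch1, batch2)
print(out.shape)                 # torch.Size([4, 2, 5]), i.e. (b x n x p)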

aerin

3 Answers


In the seq2seq model, the encoder encodes the input sequences given as mini-batches. Say, for example, the input is B x S x d, where B is the batch size, S is the maximum sequence length, and d is the word embedding dimension. Then the encoder's output is B x S x h, where h is the hidden state size of the encoder (which is an RNN).

Now, while decoding (during training), the input sequences are given one time step at a time, so the input is B x 1 x d and the decoder produces a tensor of shape B x 1 x h. To compute the context vector, we need to compare this decoder hidden state with the encoder's encoded states.

So, say you have two tensors: T1 of shape B x S x h and T2 of shape B x 1 x h. You can then do a batch matrix multiplication as follows.

out = torch.bmm(T1, T2.transpose(1, 2))

Essentially you are multiplying a tensor of shape B x S x h with a tensor of shape B x h x 1, and the result is B x S x 1, which is the attention weight for each batch.

Here, the attention weights (B x S x 1) represent a similarity score between the decoder's current hidden state and all of the encoder's hidden states. Now you can multiply the encoder's hidden states (B x S x h, transposed first to B x h x S) with the attention weights, which results in a tensor of shape B x h x 1. If you then squeeze at dim=2, you get a tensor of shape B x h, which is your context vector.

This context vector (B x h) is usually concatenated with the decoder's hidden state (B x 1 x h, squeezed at dim=1) to predict the next token.
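Putting both steps together, a minimal runnable sketch (B, S, h are made-up sizes; the softmax normalization is the usual intermediate step, even though it isn't spelled out above):

import torch

B, S, h = 4, 10, 16
encoder_states = torch.randn(B, S, h)   # T1: all encoder hidden states
decoder_state = torch.randn(B, 1, h)    # T2: current decoder hidden state

# Step 1: similarity scores, (B x S x h) @ (B x h x 1) -> (B x S x 1)
scores = torch.bmm(encoder_states, decoder_state.transpose(1, 2))
attn_weights = torch.softmax(scores, dim=1)  # normalize over the S positions

# Step 2: context vector, (B x h x S) @ (B x S x 1) -> (B x h x 1), then squeeze
context = torch.bmm(encoder_states.transpose(1, 2), attn_weights).squeeze(2)
print(context.shape)  # torch.Size([4, 16]) == (B x h)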

Wasi Ahmad
  • While you're right about the general implementation of seq2seq, in the tutorial the OP is asking about, there's no batch (B=1), so the bmm is redundant - see my answer – ihadanny Dec 05 '20 at 21:32

The operations depicted in the above figure happen on the decoder side of the seq2seq model. This means that the encoder outputs are already batched (with mini-batch-size samples). Consequently, the attn_weights tensor should also be in batch mode.

Thus, in essence, the first dimension (the zeroth axis in NumPy terminology) of both attn_weights and encoder_outputs is the mini-batch size, i.e. the number of samples. That is why we need torch.bmm on these two tensors.
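A minimal sketch of this point, with made-up sizes (the leading axis of both tensors is the mini-batch dimension, which is exactly what torch.bmm iterates over):

import torch

B, S, h = 4, 10, 256
attn_weights = torch.randn(B, 1, S)      # one weight per encoder position, per sample
encoder_outputs = torch.randn(B, S, h)   # all encoder hidden states, per sample

attn_applied = torch.bmm(attn_weights, encoder_outputs)
print(attn_applied.shape)  # torch.Size([4, 1, 256]): one context vector per sample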

kmario23
  • While you're right about the general implementation of seq2seq, in the tutorial the OP is asking about, there's no batch (B=1), so the bmm is redundant - see my answer – ihadanny Dec 05 '20 at 21:48

While @wasiahmad is right about the general implementation of seq2seq, in the mentioned tutorial there's no batch (B=1), so the bmm is just over-engineering and can safely be replaced with matmul with exactly the same model quality and performance. See for yourself: replace this:

        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                                 encoder_outputs.unsqueeze(0))
        output = torch.cat((embedded[0], attn_applied[0]), 1)

with this:

        attn_applied = torch.matmul(attn_weights,
                                    encoder_outputs)
        output = torch.cat((embedded[0], attn_applied), 1)

and run the notebook.
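If you want a quick shape check of the equivalence before running the whole notebook, here is a rough sketch (sizes are made up; in the tutorial B = 1):

import torch

S, h = 10, 256
attn_weights = torch.randn(1, S)      # (1 x S), as in the tutorial's decoder step
encoder_outputs = torch.randn(S, h)   # (S x h)

via_bmm = torch.bmm(attn_weights.unsqueeze(0), encoder_outputs.unsqueeze(0))  # (1 x 1 x h)
via_matmul = torch.matmul(attn_weights, encoder_outputs)                      # (1 x h)

print(torch.allclose(via_bmm[0], via_matmul))  # True: same values, one less batch dim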


Also, note that while @wasiahmad describes the encoder input as B x S x d, in PyTorch 1.7.0 the GRU, which is the main engine of the encoder, expects an input of shape (seq_len, batch, input_size) by default. If you want to work with @wasiahmad's format, pass the batch_first=True flag.
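A quick sketch of the two input layouts (sizes are made up):

import torch
import torch.nn as nn

B, S, d, h = 4, 10, 8, 16
gru = nn.GRU(input_size=d, hidden_size=h)                        # expects (seq_len, batch, input_size)
gru_bf = nn.GRU(input_size=d, hidden_size=h, batch_first=True)   # expects (batch, seq_len, input_size)

out1, _ = gru(torch.randn(S, B, d))      # out1: (S, B, h)
out2, _ = gru_bf(torch.randn(B, S, d))   # out2: (B, S, h), i.e. @wasiahmad's B x S x h

print(out1.shape, out2.shape)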

ihadanny