20

What is the difference between attn_mask and key_padding_mask in MultiheadAttention in PyTorch:

key_padding_mask – if provided, specified padding elements in the key will be ignored by the attention. When given a binary mask and a value is True, the corresponding value on the attention layer will be ignored. When given a byte mask and a value is non-zero, the corresponding value on the attention layer will be ignored.

attn_mask – 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all the batches while a 3D mask allows to specify a different mask for the entries of each batch.

Thanks in advance.

one

2 Answers

28

The key_padding_mask is used to mask out positions that are padding, i.e., positions after the end of the input sequence. This is always specific to the input batch and depends on how long the sequences in the batch are compared to the longest one. It is a 2D tensor of shape batch size × input length.
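
For illustration, here is a minimal sketch (the sizes and the lengths tensor are made up for the example; batch_first=True assumes a reasonably recent PyTorch) of how such a padding mask is typically built from per-sequence lengths and passed to nn.MultiheadAttention:

import torch
import torch.nn as nn

batch_size, max_len, embed_dim, num_heads = 2, 5, 16, 4
lengths = torch.tensor([5, 3])  # the second sequence is padded after position 3

# True marks padding positions that attention should ignore
key_padding_mask = torch.arange(max_len)[None, :] >= lengths[:, None]
# shape (batch size, input length):
# tensor([[False, False, False, False, False],
#         [False, False, False,  True,  True]])

mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
x = torch.rand(batch_size, max_len, embed_dim)
out, attn_weights = mha(x, x, x, key_padding_mask=key_padding_mask)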

On the other hand, attn_mask says which query-key pairs are valid. In a Transformer decoder, a triangular mask is used to simulate inference time and prevent attending to the "future" positions. This is what attn_mask is usually used for. If it is a 2D tensor, its shape is input length × input length. You can also have a mask that is specific to every item in a batch; in that case, you can use a 3D tensor of shape (batch size × num heads) × input length × input length. (So, in theory, you can simulate key_padding_mask with a 3D attn_mask.)
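
For illustration, a minimal sketch (sizes made up) of the usual causal triangle mask passed as attn_mask; True again means "do not attend":

import torch
import torch.nn as nn

seq_len, embed_dim, num_heads = 5, 16, 4

# 2D causal mask of shape (input length, input length):
# position i may only attend to positions <= i; the upper triangle (the "future") is masked
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
x = torch.rand(2, seq_len, embed_dim)
out, attn_weights = mha(x, x, x, attn_mask=causal_mask)

# a per-item mask would instead be a 3D tensor of shape
# (batch size * num heads, input length, input length)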

tigertang
Jindřich
    What would be the purpose of having a mask that is specific to every item in the batch? Curious. – Brofessor Aug 06 '20 at 15:59
  • there could be padding at different positions for each item in the batch. E.g., if the input is a series of sentences padded at the beginning or end, we need to apply an individual mask to each sentence. This mask will be a combination of attn_mask and key_padding_mask in the case of a decoder (referring to encoder inputs for keys, values) – Allohvk Dec 29 '20 at 13:49
    when passing masks for each item in a batch, does the module use sequential items along the 0 dimension for each attention head? i.e. when `batch_size=32` and `num_heads=4`, are `att_mask[:4,:,:]` the masks for item 1 (for head 1, 2, 3 and 4)? – skurp Sep 23 '21 at 20:36
3

I think they work the same way: both masks define which query-key attention pairs will not be used. The only difference between the two is the shape in which you find it more convenient to supply the mask.

According to the code, it seems that the two masks are merged (their union is taken), so they play the same role: deciding which query-key attention pairs will not be used. Because the union is taken, the two mask inputs can hold different values if you really need to use both, or you can pass your mask through whichever argument's required shape is more convenient. Here is part of the original code from pytorch/functional.py, around line 5227, in the function multi_head_attention_forward():

...
# merge key padding and attention masks
if key_padding_mask is not None:
    assert key_padding_mask.shape == (bsz, src_len), \
        f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
    key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len).   \
        expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
    if attn_mask is None:
        attn_mask = key_padding_mask
    elif attn_mask.dtype == torch.bool:
        attn_mask = attn_mask.logical_or(key_padding_mask)
    else:
        attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))
...
# so here only the merged/unioned mask is used to actually compute the attention
attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p)
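
As a quick sanity check of that merging behaviour, here is a minimal sketch (sizes made up; the exact numerics may vary slightly across PyTorch versions) showing that a padding mask passed via key_padding_mask, or expressed as an equivalent broadcast 3D attn_mask, gives the same result:

import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, embed_dim, num_heads = 2, 4, 8, 2
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
x = torch.rand(batch_size, seq_len, embed_dim)

key_padding_mask = torch.tensor([[False, False, True, True],
                                 [False, False, False, True]])

# the same information expressed as a 3D attn_mask of shape
# (batch size * num heads, target length, source length)
attn_mask = (key_padding_mask[:, None, None, :]
             .expand(-1, num_heads, seq_len, -1)
             .reshape(batch_size * num_heads, seq_len, seq_len))

out_kpm, _ = mha(x, x, x, key_padding_mask=key_padding_mask)
out_am, _ = mha(x, x, x, attn_mask=attn_mask)
print(torch.allclose(out_kpm, out_am))  # expected: True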

Please correct me if you have a different opinion or if I am wrong.

Flora Sun