The following is my understanding. Correct me if I'm wrong.
I think the key to understanding how the attention mask is computed is the difference between the attention_mask expected by multi-head attention and the embedding mask generated by the embedding layer.
tf.keras.layers.Embedding is a mask-generating layer: for an input of shape (batch_size, input_length), it generates an embedding mask of the same shape (batch_size, input_length) (https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding#input-shape).
tf.keras.layers.MultiHeadAttention is a mask-consuming layer. When the output tensor of tf.keras.layers.Embedding is passed to tf.keras.layers.MultiHeadAttention, the embedding mask also needs to be passed to that layer. However, tf.keras.layers.MultiHeadAttention expects an "attention_mask", which is different from the embedding mask: "attention_mask" is a boolean mask of shape (B, T, S) (https://www.tensorflow.org/api_docs/python/tf/keras/layers/MultiHeadAttention#call-arguments_1), where B is the batch size, T the target (query) length, and S the source (key) length.
To compute the attention mask for self-attention, we essentially take an outer product (https://en.wikipedia.org/wiki/Outer_product) of the padding mask with itself: for a row mask vector $m$ (one entry per token), we compute $m^T m$. The result is a matrix of the same shape as the attention matrix, where element $(i, j)$ indicates whether token $i$ is allowed to attend to token $j$.
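As a minimal sketch of that logical outer product (using a made-up 1-D padding mask, not the example below), the per-sequence (T, S) mask can be built by broadcasting the mask vector against itself:
import tensorflow as tf

m = tf.constant([True, True, True, False, False])  # hypothetical padding mask, shape (S,)
pairwise = m[:, None] & m[None, :]                 # logical outer product, shape (S, S)
print(pairwise.numpy().astype(int))
# [[1 1 1 0 0]
#  [1 1 1 0 0]
#  [1 1 1 0 0]
#  [0 0 0 0 0]
#  [0 0 0 0 0]]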
The & operator in mask1 & mask2 is tf.math.logical_and.
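A quick sanity check (with arbitrary small boolean tensors) that the operator and the function are interchangeable and broadcast the same way:
import tensorflow as tf

mask1 = tf.constant([[True], [False]])   # shape (2, 1)
mask2 = tf.constant([[True, False]])     # shape (1, 2)
a = mask1 & mask2                        # operator form, broadcasts to (2, 2)
b = tf.math.logical_and(mask1, mask2)    # explicit form, same result
print(bool(tf.reduce_all(a == b)))       # True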
A basic example to understand the attention mask in tf.keras.layers.MultiHeadAttention:
import tensorflow as tf
from sklearn import preprocessing

sequence_a = "This is a very long sequence"
sequence_b = "This is short"
text = (sequence_a + ' ' + sequence_b).split(' ')

# Fit a label encoder on the combined vocabulary
le = preprocessing.LabelEncoder()
le.fit(text)
print(le.classes_)
['This' 'a' 'is' 'long' 'sequence' 'short' 'very']
# Shift the labels by 1 so that id 0 is free to be used as the padding token
_tokens_a = le.transform(sequence_a.split(' ')) + 1  # 1-based
_tokens_b = le.transform(sequence_b.split(' ')) + 1

# Right-pad the shorter sequence with 0 so both have the same length
pad_b = tf.constant([[0, _tokens_a.size - _tokens_b.size]])
tokens_b = tf.pad(_tokens_b, pad_b)
tokens_a = tf.constant(_tokens_a)
print(tokens_a)
tf.Tensor([1 3 2 7 4 5], shape=(6,), dtype=int64)
print(tokens_b)
tf.Tensor([1 3 6 0 0 0], shape=(6,), dtype=int64)
padded_batch = tf.concat([tokens_a[None,:], tokens_b[None,:]], axis=0)
padded_batch # Shape `(batch_size, input_seq_len)`.
Tokenization result:
<tf.Tensor: shape=(2, 6), dtype=int64, numpy=
array([[1, 3, 2, 7, 4, 5],
[1, 3, 6, 0, 0, 0]])>
Embedding mask and attention mask:
# mask_zero=True: input id 0 is treated as padding and excluded via the generated mask
embedding = tf.keras.layers.Embedding(10, 4, mask_zero=True)
embedding_batch = embedding(padded_batch)
embedding_batch
<tf.Tensor: shape=(2, 6, 4), dtype=float32, numpy=
array([[[-0.0395105 , 0.02781621, -0.02362361, 0.01861998],
[ 0.02881015, 0.03395045, -0.0079098 , -0.002824 ],
[ 0.02268535, -0.02632991, 0.03217204, -0.03376112],
[ 0.04794324, 0.01584867, 0.02413819, 0.01202248],
[-0.03509659, 0.04907972, -0.00174795, -0.01215838],
[-0.03295932, 0.02424154, -0.04788723, -0.03202241]],
[[-0.0395105 , 0.02781621, -0.02362361, 0.01861998],
[ 0.02881015, 0.03395045, -0.0079098 , -0.002824 ],
[-0.02425164, -0.04932282, 0.0186419 , -0.01743554],
[-0.00052293, 0.01411307, -0.01286217, 0.00627784],
[-0.00052293, 0.01411307, -0.01286217, 0.00627784],
[-0.00052293, 0.01411307, -0.01286217, 0.00627784]]],
dtype=float32)>
embedding_mask = embedding_batch._keras_mask  # equivalently: embedding.compute_mask(padded_batch)
embedding_mask
<tf.Tensor: shape=(2, 6), dtype=bool, numpy=
array([[ True, True, True, True, True, True],
[ True, True, True, False, False, False]])>
# This is self-attention, so Q and K come from the same sequence
my_mask1 = embedding_mask[:, :, None]  # equivalent to embedding_mask[:, :, tf.newaxis]
my_mask1
<tf.Tensor: shape=(2, 6, 1), dtype=bool, numpy=
array([[[ True],
[ True],
[ True],
[ True],
[ True],
[ True]],
[[ True],
[ True],
[ True],
[False],
[False],
[False]]])>
# This is self-attention, so Q and K come from the same sequence
my_mask2 = embedding_mask[:, None, :]
my_mask2
<tf.Tensor: shape=(2, 1, 6), dtype=bool, numpy=
array([[[ True, True, True, True, True, True]],
[[ True, True, True, False, False, False]]])>
# According to the `attention_mask` call argument of `tf.keras.layers.MultiHeadAttention`
# (https://www.tensorflow.org/api_docs/python/tf/keras/layers/MultiHeadAttention#call-arguments_1),
# this is the attention_mask: a boolean mask of shape (B, T, S)
my_attention_mask = my_mask1 & my_mask2
my_attention_mask #[batch_size, input_seq_len, input_seq_len]
<tf.Tensor: shape=(2, 6, 6), dtype=bool, numpy=
array([[[ True, True, True, True, True, True],
[ True, True, True, True, True, True],
[ True, True, True, True, True, True],
[ True, True, True, True, True, True],
[ True, True, True, True, True, True],
[ True, True, True, True, True, True]],
[[ True, True, True, False, False, False],
[ True, True, True, False, False, False],
[ True, True, True, False, False, False],
[False, False, False, False, False, False],
[False, False, False, False, False, False],
[False, False, False, False, False, False]]])>
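To close the loop, here is a sketch of how the mask would be consumed; num_heads=2 and key_dim=4 are arbitrary choices for this illustration. Note that in recent TF/Keras versions, MultiHeadAttention can also derive this padding mask automatically from the Keras mask attached to embedding_batch, so passing attention_mask explicitly as below should mask the padded key positions the same way (the exact automatic behavior depends on your Keras version):
mha = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=4)

# Explicitly pass the (B, T, S) boolean mask computed above
output, scores = mha(query=embedding_batch, value=embedding_batch,
                     attention_mask=my_attention_mask,
                     return_attention_scores=True)
print(output.shape)  # (2, 6, 4) -- (batch_size, input_seq_len, embedding_dim)
print(scores.shape)  # (2, 2, 6, 6) -- (batch_size, num_heads, T, S)

# For the real (non-padded) query rows of sequence_b, the attention weights on
# the padded key columns are driven to ~0 by the mask
print(scores[1, 0].numpy().round(2))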