I keep getting this error when trying to apply self-attention to 1D vectors. I don't understand why it happens; any help would be greatly appreciated.

layer = layers.MultiHeadAttention(num_heads=2, key_dim=2)
target = tf.ones(shape=[1, 16])
source = tf.ones(shape=[1, 16])
output_tensor, weights = layer(target, source)

The error:

~/anaconda3/envs/tf/lib/python3.9/site-packages/tensorflow/python/keras/layers/multi_head_attention.py in _masked_softmax(self, attention_scores, attention_mask)
    399         attention_mask = array_ops.expand_dims(
    400             attention_mask, axis=mask_expansion_axes)
--> 401     return self._softmax(attention_scores, attention_mask)
    402 
    403   def _compute_attention(self,

~/anaconda3/envs/tf/lib/python3.9/site-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
   1010         with autocast_variable.enable_auto_cast_variables(
   1011             self._compute_dtype_object):
-> 1012           outputs = call_fn(inputs, *args, **kwargs)
   1013 
   1014         if self._activity_regularizer:

~/anaconda3/envs/tf/lib/python3.9/site-packages/tensorflow/python/keras/layers/advanced_activations.py in call(self, inputs, mask)
    332             inputs, axis=self.axis, keepdims=True))
    333       else:
--> 334         return K.softmax(inputs, axis=self.axis[0])
    335     return K.softmax(inputs, axis=self.axis)
    336 

IndexError: tuple index out of range
AloneTogether

1 Answer

Your inputs are missing the batch dimension, which the layer requires: it expects 3D tensors of shape (batch, sequence length, feature dimension). Also, if you want both the output tensor and the corresponding attention weights, you have to set the parameter return_attention_scores to True. Try something like this:

import tensorflow as tf

layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
samples = 5
target = tf.ones(shape=[samples, 1, 16])
source = tf.ones(shape=[samples, 1, 16])
output_tensor, weights = layer(target, source, return_attention_scores=True)

Also according to the docs:

query: Query Tensor of shape (B, T, dim)

value: Value Tensor of shape (B, S, dim)
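
To make those shapes concrete, here is a small sketch that prints what the layer returns for 3D inputs. The second half is only an assumption about what "self attention on 1D vectors" might mean: it treats the 16 values as a sequence of 16 tokens with one feature each, and the names `vector`, `tokens` and `self_layer` are made up for illustration.

```python
import tensorflow as tf

layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)

batch, seq_len, dim = 5, 1, 16
query = tf.ones(shape=[batch, seq_len, dim])  # (B, T, dim)
value = tf.ones(shape=[batch, seq_len, dim])  # (B, S, dim)

output, scores = layer(query, value, return_attention_scores=True)
print(output.shape)  # (B, T, dim)
print(scores.shape)  # (B, num_heads, T, S)

# Assumption: interpret a single 16-element vector as 16 tokens of size 1,
# then attend over those tokens by passing the same tensor as query and value.
vector = tf.ones(shape=[1, 16])
tokens = tf.reshape(vector, [1, 16, 1])  # (B=1, T=16, dim=1)

self_layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
self_output, self_scores = self_layer(
    tokens, tokens, return_attention_scores=True
)
print(self_output.shape)  # (B, T, dim)
print(self_scores.shape)  # (B, num_heads, T, T)
```

Whether the 16 values should be one token with 16 features (as in the answer's code) or 16 tokens with one feature depends on what you want to attend over. With a single token the attention weights are trivially 1, so the second layout may be closer to what you intended.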

AloneTogether