I am trying to convert a TensorFlow model that uses layers.MultiHeadAttention from tf.keras to nn.MultiheadAttention from torch.nn. Below are the snippets.
- TensorFlow Multi-head Attention
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
x_sfe_tf = np.random.randn(64, 345, 64)
x_te_tf = np.random.randn(64, 200, 64)
tes_mod_tf = layers.MultiHeadAttention(num_heads=2, key_dim=64)
output_tf = tes_mod_tf(x_sfe_tf, x_te_tf)
print(output_tf.shape)
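For context, my understanding from the Keras docs is that the call signature here is mha(query, value, key=None), where key falls back to value when it is omitted, so I believe the call above is equivalent to this explicit form (please correct me if I've misread the API):
output_tf = tes_mod_tf(query=x_sfe_tf, value=x_te_tf)  # key defaults to value (x_te_tf)
print(output_tf.shape)  # (64, 345, 64): output keeps the query's sequence length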
- PyTorch Multi-head Attention
import torch
import torch.nn as nn
x_sfe_torch = torch.randn(64, 345, 64)
x_te_torch = torch.randn(64, 200, 64)
tes_mod_torch = nn.MultiheadAttention(embed_dim=64, num_heads=2)
output_torch = tes_mod_torch(x_sfe_torch, x_sfe_torch, x_te_torch)  # query, key, value
print(output_torch.shape)
When I run TensorFlow's MHA, it successfully returns (64, 345, 64). But when I run PyTorch's MHA, it raises this error:
AssertionError: key shape torch.Size([64, 345, 64]) does not match value shape torch.Size([64, 200, 64])
The TensorFlow version returns an output with the shape of x_sfe, ignoring the sequence-length difference from x_te. On the other hand, the PyTorch version seems to require x_sfe and x_te to have the same sequence length. I am confused about how TensorFlow's Multi-head Attention module actually works. What is the difference in PyTorch's version, and what is the correct input for it? Thanks in advance.
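For what it's worth, my current guess is that the TF call mha(query=x_sfe, value=x_te) maps to PyTorch as query=x_sfe and key=value=x_te, with batch_first=True so the inputs stay in (batch, seq, embed) order like in TensorFlow, and that PyTorch's forward returns an (output, weights) tuple. The sketch below runs without the assertion error, but I am not sure it is really equivalent (for example, TF uses key_dim=64 per head while PyTorch splits embed_dim=64 across the 2 heads):
# My guess at an equivalent call; not sure this is the correct mapping
tes_mod_torch = nn.MultiheadAttention(embed_dim=64, num_heads=2, batch_first=True)
attn_output, attn_weights = tes_mod_torch(x_sfe_torch, x_te_torch, x_te_torch)  # query, key, value
print(attn_output.shape)  # torch.Size([64, 345, 64])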