I have implemented a custom Transformer model using PyTorch. My model is primarily based on nn.TransformerEncoder and nn.TransformerEncoderLayer. Here is my code:
import math

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        # Shape (1, max_len, d_model) so it broadcasts over batch-first input.
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
        self.max_len = max_len

    def forward(self, x: Tensor) -> Tensor:
        # x: (batch_size, seq_len, d_model); add the encoding for the first seq_len positions.
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
class TransformerClassifier(nn.Module):
    def __init__(self, max_len: int, ntoken: int, embed_dim: int, num_heads: int, num_layers: int,
                 hidden_dim: int, num_classes: int, dropout: float = 0.5):
        super().__init__()
        self.token_embedding = nn.Embedding(ntoken, embed_dim)
        self.positional_embedding = PositionalEncoding(d_model=embed_dim, max_len=max_len)
        encoder_layers = TransformerEncoderLayer(embed_dim, num_heads, hidden_dim, dropout, batch_first=True)
        self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers)
        self.fc = nn.Linear(embed_dim, num_classes)
        self.embed_dim = embed_dim

    def forward(self, x):
        # x: (batch_size, seq_len) of token ids.
        x = self.token_embedding(x) * math.sqrt(self.embed_dim)
        x = self.positional_embedding(x)
        x = self.transformer_encoder(x)
        # Mean-pool over the sequence dimension before classifying.
        x = x.mean(dim=1)
        return self.fc(x)
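For context, here is roughly how I build and run the model; the hyperparameter values below are arbitrary placeholders I picked for illustration:

# Arbitrary example hyperparameters, for illustration only.
model = TransformerClassifier(max_len=128, ntoken=10000, embed_dim=64,
                              num_heads=4, num_layers=2, hidden_dim=256,
                              num_classes=5)
model.eval()

tokens = torch.randint(0, 10000, (8, 128))  # dummy batch: (batch_size, seq_len)
with torch.no_grad():
    logits = model(tokens)  # (8, 5)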
I want to extract the attention matrix of each individual attention head inside every MultiheadAttention module after running input data through the model, without modifying the code of the pretrained model. Each MultiheadAttention module has num_heads heads, so each module should yield num_heads attention matrices, one per head.
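To make the goal concrete, here is the naive forward-hook approach I would reach for first (a rough sketch; attention_maps and make_hook are names I made up). As far as I can tell it cannot work as-is, because TransformerEncoderLayer calls its self-attention with need_weights=False, so the weights entry of the hooked output is None; and even when weights are returned, MultiheadAttention averages them across heads by default unless average_attn_weights=False is also passed.

# Rough sketch of what I have in mind (names are hypothetical).
attention_maps = {}

def make_hook(layer_name):
    def hook(module, inputs, outputs):
        # outputs is (attn_output, attn_weights); attn_weights is None here
        # because the encoder layer requests need_weights=False internally.
        attention_maps[layer_name] = outputs[1]
    return hook

for i, layer in enumerate(model.transformer_encoder.layers):
    layer.self_attn.register_forward_hook(make_hook(f"layer_{i}"))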
How can I go about extracting the individual per-head attention matrices from each MultiheadAttention module in my custom PyTorch Transformer model?