I am currently experimenting with a model that uses the torchvision implementation of MViT_v2_s as its backbone. I added a few cross-attention modules to the model, which look roughly like this:
import torch
import torch.nn as nn
from torchvision.ops import MLP  # stand-in for the original two-layer MLP


class FusionModule(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, source_a_input_channels: int, source_b_input_channels: int):
        super().__init__()
        # embed_dim = source_a_input_channels = source_b_input_channels
        self.attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)
        self.source_a_pool = nn.LayerNorm(source_a_input_channels)
        self.source_b_pool = nn.LayerNorm(source_b_input_channels)
        self.proj_norm = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, [embed_dim * 4, embed_dim])  # a two-layer MLP

    def forward(self, source_a: torch.Tensor, source_b: torch.Tensor):
        # source_a takes the output of an MViT multiscale block, shape (b, thw + 1, c)
        source_a = self.source_a_pool(source_a)
        # reshape the source_b input from (b, c, t, h, w) to (b, thw, c)
        source_b = self.source_b_pool(source_b.flatten(2).transpose(1, 2))
        # after the reshape, source_b has almost the same shape as source_a,
        # except source_b has one less token
        fused = self.attn(source_a, source_b, source_b)[0]
        mid_prod = source_a + fused  # residual around the cross-attention
        mid_prod = self.proj_norm(mid_prod)
        out = self.mlp(mid_prod)
        out = mid_prod + out  # residual around the MLP
        return out
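For a concrete sense of the shapes involved, here is a minimal shape check on dummy inputs. The head count and the raw source_b shape (1, 96, 8, 56, 56) are my assumptions, chosen so the token counts match the first stage discussed below (25089 query tokens vs. 25088 key/value tokens):

fusion = FusionModule(embed_dim=96, num_heads=1, source_a_input_channels=96, source_b_input_channels=96)
source_a = torch.randn(1, 25089, 96)      # output of the first backbone stage
source_b = torch.randn(1, 96, 8, 56, 56)  # flattens to (1, 25088, 96)
out = fusion(source_a, source_b)
print(out.shape)  # torch.Size([1, 25089, 96])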
I added this module after every "stage" of the MViT backbone network (i.e., after the 1st, 3rd, 14th and 16th multiscale blocks).
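Concretely, the placement looks something like the sketch below. The loop structure, the (x, thw) block interface and the list of per-stage source_b features are assumptions on my part, not the actual integration code:

# hypothetical wiring: run a fusion module after the 1st, 3rd, 14th and 16th block
fusion_after = {0: 0, 2: 1, 13: 2, 15: 3}  # 0-based block index -> fusion module index

def backbone_with_fusion(blocks, fusion_modules, x, thw, source_b_feats):
    # blocks: the MViT multiscale blocks, each assumed to take and return (x, thw)
    for i, block in enumerate(blocks):
        x, thw = block(x, thw)
        if i in fusion_after:
            x = fusion_modules[fusion_after[i]](x, source_b_feats[fusion_after[i]])
    return x, thw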
My problem is, when I feed an empty input tensor of shape (1, 3, 16, 224, 224) into my model (i.e. torch.empty(1, 3, 16, 224, 224)), the calculation time of the first FusionModule is significantly longer than that of the other modules: it takes about 2.7 seconds just to compute the cross-attention. Meanwhile, the first stage of the MViT backbone, which contains a single self-attention module plus some other components, takes only about 0.2 seconds.
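For reference, per-module wall-clock timings on CPU can be collected with forward hooks along these lines (a sketch only; the attribute names in the usage comment are hypothetical):

import time
import torch
import torch.nn as nn

timings = {}

def add_timing_hooks(name: str, module: nn.Module):
    # record wall-clock time spent inside this module's forward;
    # reliable on CPU, where execution is synchronous
    module.register_forward_pre_hook(lambda m, inp: timings.__setitem__(name, time.perf_counter()))
    module.register_forward_hook(lambda m, inp, out: timings.__setitem__(name, time.perf_counter() - timings[name]))

# e.g.:
#   add_timing_hooks("fusion_0", model.fusion_modules[0])
#   with torch.no_grad():
#       model(torch.empty(1, 3, 16, 224, 224))
#   print(timings)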
Technically, the FLOP count of the MViT backbone block should be almost the same as that of my FusionModule (the backbone block computes self-attention on a [1, 25089, 96] tensor; my FusionModule computes cross-attention between a query tensor of shape [1, 25089, 96] and a key/value tensor of shape [1, 25088, 96]).
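As a back-of-the-envelope check of that claim (counting a multiply-add as two FLOPs and ignoring softmax, biases and normalization), the two attention calls come out essentially identical:

# rough FLOP estimate for a standard multi-head attention layer
def attn_flops(n_q, n_kv, d):
    projections = 2 * (2 * n_q + 2 * n_kv) * d * d  # Q and output projections on n_q tokens, K/V on n_kv tokens
    scores = 2 * n_q * n_kv * d                     # Q @ K^T
    weighted_sum = 2 * n_q * n_kv * d               # softmax(Q K^T) @ V
    return projections + scores + weighted_sum

print(attn_flops(25089, 25089, 96) / 1e9)  # self-attention as described above: ~243.6 GFLOPs
print(attn_flops(25089, 25088, 96) / 1e9)  # FusionModule cross-attention:      ~243.6 GFLOPs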
Is MViT_v2's implementation of self-attention really that much more efficient than PyTorch's own nn.MultiheadAttention? Or do they actually have similar computational speed, and this happened because I messed up my model somewhere?
I've confirmed that the MLP in my FusionModule accounts for only a tiny fraction of that 2.7 seconds. Also, the calculation was done on the CPU; apparently my laptop's GPU can't handle this model of mine.
PS: I've tried running a vanilla mvit_v2_s. It takes 1.2 seconds to process an input tensor, which is roughly the same amount of time my model spends on the backbone network, so I assume there's nothing wrong with my backbone.