
I'm trying to implement, in PyTorch, the 1D self-attention block below, proposed in the following paper:

[Image: diagram of the 1D self-attention block]

Below is my (provisional) attempt:

```python
import torch
import torch.nn as nn

# Input shape: (B, C, H, W)


class Self_Attention1D(nn.Module):

    def __init__(self, in_channels=1, out_channels=3):
        super().__init__()
        # A 1x1 kernel on a 4D (B, C, H, W) input needs Conv2d; Conv1d expects (B, C, L).
        self.pointwise_conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, 1))
        self.pointwise_conv2 = nn.Conv2d(in_channels=out_channels, out_channels=in_channels, kernel_size=(1, 1))
        self.phi = MLP(in_size=out_channels, out_size=32)
        self.psi = MLP(in_size=out_channels, out_size=32)
        # gamma maps the relation back to out_channels for the Hadamard product
        self.gamma = MLP(in_size=32, out_size=out_channels)

    def forward(self, x):
        x = self.pointwise_conv1(x)  # (B, out_channels, H, W)
        # nn.Linear acts on the last dim, so move channels last: (B, W, H, out_channels)
        phi = self.phi(x.transpose(1, 3))
        psi = self.psi(x.transpose(1, 3))
        delta = phi - psi  # (B, W, H, 32), phi and psi taken at the same position
        gamma = self.gamma(delta).transpose(3, 1)  # back to (B, out_channels, H, W)
        # Hadamard product, then project back to in_channels
        out = self.pointwise_conv2(torch.mul(gamma, x))
        return out


class MLP(nn.Module):

    def __init__(self, in_size, out_size):
        super().__init__()
        self.in_size = in_size
        self.out_size = out_size
        self.layers = nn.Sequential(
            nn.Linear(in_size, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, out_size))

    def forward(self, x):
        out = self.layers(x)
        return out
```

I'm not at all sure this is correct: in my implementation the operations are applied to all entries at once, whereas the figure shows an operation computed between each entry and its neighbours, one at a time. I was initially tempted to write a for loop that computes the networks delta, phi and psi for each entry iteratively, but it didn't feel like the right approach.
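The neighbour-by-neighbour computation does not need an explicit Python loop. Because delta = phi(x_i) - psi(x_j) only combines per-position features, phi and psi can be evaluated once per entry and then paired up with broadcasting. The sketch below assumes a 1D sequence of features of shape (B, L, D), a symmetric window of `radius` neighbours on each side, and zero padding at the borders; these choices are assumptions, not taken from the paper, and `neighbour_deltas` / `neighbour_deltas_loop` are hypothetical helper names. The loop version exists only to check the vectorised one.

```python
import torch
import torch.nn.functional as F


def neighbour_deltas(phi_x: torch.Tensor, psi_x: torch.Tensor, radius: int) -> torch.Tensor:
    """phi(x_i) - psi(x_j) for every position i and every j in [i - radius, i + radius].

    phi_x, psi_x: (B, L, D), already passed through the phi / psi networks.
    Returns (B, L, K, D) with K = 2 * radius + 1 (zero padding at the borders).
    """
    k = 2 * radius + 1
    # Zero-pad the sequence axis so every position has a full window.
    psi_padded = F.pad(psi_x, (0, 0, radius, radius))  # (B, L + 2r, D)
    # Sliding windows of size k with step 1 along dim 1: (B, L, D, K)
    windows = psi_padded.unfold(1, k, 1).permute(0, 1, 3, 2)  # (B, L, K, D)
    # Broadcast phi(x_i) against all K neighbours of position i.
    return phi_x.unsqueeze(2) - windows


def neighbour_deltas_loop(phi_x, psi_x, radius):
    """Reference version with explicit loops, one (i, j) pair at a time."""
    B, L, D = phi_x.shape
    out = torch.zeros(B, L, 2 * radius + 1, D)
    for i in range(L):
        for m, j in enumerate(range(i - radius, i + radius + 1)):
            neighbour = psi_x[:, j] if 0 <= j < L else torch.zeros(B, D)
            out[:, i, m] = phi_x[:, i] - neighbour
    return out


torch.manual_seed(0)
phi_x = torch.randn(2, 10, 32)
psi_x = torch.randn(2, 10, 32)
fast = neighbour_deltas(phi_x, psi_x, radius=1)
slow = neighbour_deltas_loop(phi_x, psi_x, radius=1)
print(fast.shape)  # torch.Size([2, 10, 3, 32])
print(torch.allclose(fast, slow))  # True
```

`Tensor.unfold` returns the sliding windows as a view of the padded tensor, so no per-neighbour copies are made before the subtraction.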

Apologies if this is trivial; I don't have much experience with PyTorch yet.

James Arten
  • It looks like the input with shape `(1,w,c)` is being sliced along the second dimension into green, red and blue parts. It is not clear from the picture what the gamma "Mapping Function" does, and the step from the Self Attention Map to the Generated SAM is also a bit unclear. As a side note, pointwise convs can be computed with a normal linear layer: https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py#L30 – Kevin Mar 21 '22 at 16:22
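To illustrate the side note about pointwise convolutions: a 1×1 `nn.Conv2d` applies the same linear map to the channel vector at every spatial position, so an `nn.Linear` with copied weights gives the same result once the channels are moved to the last dimension (the same idea as in the linked ConvNeXt code). The sizes below are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
in_ch, out_ch = 4, 8
x = torch.randn(2, in_ch, 5, 7)  # (B, C, H, W)

conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)
linear = nn.Linear(in_ch, out_ch)

# Copy the conv parameters into the linear layer: conv.weight is (out, in, 1, 1).
with torch.no_grad():
    linear.weight.copy_(conv.weight.view(out_ch, in_ch))
    linear.bias.copy_(conv.bias)

y_conv = conv(x)  # (B, out, H, W)
# nn.Linear acts on the last dimension: move channels last, apply, move them back.
y_linear = linear(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

print(torch.allclose(y_conv, y_linear, atol=1e-5))  # True
```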
  • The gamma function simply projects the result into the same channel space as the pointwise conv, so that the Hadamard product can be performed. What isn't clear to me is how to implement the boxed operation so that the delta function takes each single entry **and** its neighbour *one at a time*. Again, I think my implementation ignores this *local* operation. – James Arten Mar 21 '22 at 16:45
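One possible way to put the pieces together is sketched below. It assumes the boxed operation means: for each position i, compute delta = phi(x_i) - psi(x_j) for every neighbour j in a local window, map it with gamma to the channel dimension, take the Hadamard product with x_j, and sum over the window. The summation, the zero padding, the window `radius`, the (B, C, L) input layout and the names `LocalPairwiseAttention1D` and `mlp` are assumptions for illustration; the figure may intend a different aggregation (for example, normalising the gamma outputs with a softmax over the window first).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def mlp(in_size: int, out_size: int, hidden: int = 64) -> nn.Sequential:
    # Hypothetical small MLP; the hidden size is a placeholder.
    return nn.Sequential(nn.Linear(in_size, hidden), nn.ReLU(), nn.Linear(hidden, out_size))


class LocalPairwiseAttention1D(nn.Module):
    """Sketch: y_i = sum_{j in window(i)} gamma(phi(x_i) - psi(x_j)) * x_j.

    Input and output shape: (B, in_channels, L).
    """

    def __init__(self, in_channels=1, out_channels=3, rel_size=32, radius=1):
        super().__init__()
        self.radius = radius
        # Pointwise convs on a (B, C, L) tensor: Conv1d with kernel_size=1.
        self.pointwise_conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=1)
        self.pointwise_conv2 = nn.Conv1d(out_channels, in_channels, kernel_size=1)
        self.phi = mlp(out_channels, rel_size)
        self.psi = mlp(out_channels, rel_size)
        self.gamma = mlp(rel_size, out_channels)  # maps delta back to channel space

    def _windows(self, t: torch.Tensor) -> torch.Tensor:
        # (B, L, D) -> (B, L, K, D): zero-padded sliding windows, K = 2 * radius + 1
        k = 2 * self.radius + 1
        t = F.pad(t, (0, 0, self.radius, self.radius))
        return t.unfold(1, k, 1).permute(0, 1, 3, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pointwise_conv1(x).transpose(1, 2)  # (B, L, C)
        phi = self.phi(x)  # (B, L, R), once per position
        psi = self.psi(x)  # (B, L, R), once per position
        delta = phi.unsqueeze(2) - self._windows(psi)  # (B, L, K, R), one per (i, j) pair
        weights = self.gamma(delta)  # (B, L, K, C)
        neighbours = self._windows(x)  # (B, L, K, C)
        # Hadamard product with each neighbour, then sum over the window.
        y = (weights * neighbours).sum(dim=2)  # (B, L, C)
        return self.pointwise_conv2(y.transpose(1, 2))  # (B, in_channels, L)


torch.manual_seed(0)
block = LocalPairwiseAttention1D(in_channels=1, out_channels=3, radius=2)
x = torch.randn(4, 1, 50)
out = block(x)
print(out.shape)  # torch.Size([4, 1, 50])
```

Compared with the original attempt, the main structural change is the extra window axis K: phi and psi still run once per position, but delta and gamma now run once per (position, neighbour) pair, which is exactly the local operation the figure appears to describe.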
