0
class Attention(nn.Module):
    """MLP-scored attention over encoder annotations.

    Each encoder annotation is concatenated with the current decoder hidden
    state and scored by a two-layer fully-connected network. The scores are
    normalized with a softmax across the sequence dimension.
    """

    def __init__(self, hidden_size):
        super(Attention, self).__init__()

        self.hidden_size = hidden_size
        # Two-layer scoring network:
        # hidden_size*2 --> hidden_size, ReLU, hidden_size --> 1
        self.attention_network = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )
        # dim=1 is the seq_len axis of the (batch_size x seq_len x 1) scores.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, hidden, annotations):
        """The forward pass of the attention mechanism.

        Arguments:
            hidden: The current decoder hidden state. (batch_size x hidden_size)
            annotations: The encoder hidden states for each step of the input sequence. (batch_size x seq_len x hidden_size)

        Returns:
            output: Normalized attention weights for each encoder hidden state. (batch_size x seq_len x 1)

            The output is a softmax weighting over the seq_len annotations,
            i.e. for every batch element the weights sum to 1 along dim=1.
        """
        # Local import: torch.cat is needed here and the file's import block
        # is not guaranteed to bind the name `torch`.
        import torch

        batch_size, seq_len, hid_size = annotations.size()
        # Repeat the decoder state along the sequence axis so it can be
        # paired with every annotation: (batch_size x seq_len x hidden_size).
        expanded_hidden = hidden.unsqueeze(1).expand_as(annotations)

        # Pair each annotation with the decoder state along the feature axis:
        # (batch_size x seq_len x 2*hidden_size).
        concat = torch.cat((expanded_hidden, annotations), dim=2)
        # Flatten batch and sequence axes so every (state, annotation) pair is
        # one row for the scoring network: (batch_size*seq_len x 2*hidden_size).
        reshaped_for_attention_net = concat.reshape(-1, hid_size * 2)
        # One scalar score per pair: (batch_size*seq_len x 1).
        attention_net_output = self.attention_network(reshaped_for_attention_net)
        # Restore the batch/sequence layout: (batch_size x seq_len x 1).
        unnormalized_attention = attention_net_output.view(batch_size, seq_len, 1)

        return self.softmax(unnormalized_attention)



In the forward function, this is what I've tried:

    concat = torch.cat((expanded_hidden, annotations), 2)
    unnormalized_attention = self.attention_network(concat)

I'm trying to figure out what should go in each of these four placeholder steps:

        concat = ...
        reshaped_for_attention_net = ...
        attention_net_output = ...
        unnormalized_attention = ...
Ani_Expo
  • 1
  • 1
  • I think what you tried is correct already. It already produces an output tensor with the expected shape and sums to 1 in `dim=1`. No need to break it down into 4 steps. – kmkurn Dec 19 '22 at 21:44

0 Answers0