
In the Stack Overflow thread "How can i add a Bi-LSTM layer on top of bert model?", there is a line of code:

hidden = torch.cat((lstm_output[:,-1, :256],lstm_output[:,0, 256:]),dim=-1)

Can someone explain why the last and first tokens are concatenated, and not any others? What do these two tokens contain that made them the ones chosen?

  • OP is trying to do sequence classification using BERT. The first token of the generated sequence should be the 'CLS' token, which is used for classification (see https://datascience.stackexchange.com/questions/66207/what-is-purpose-of-the-cls-token-and-why-is-its-encoding-output-important). Because the model is bidirectional, I think they are trying to get the "first" token in both directions. – DWKOT Oct 21 '22 at 16:12
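For context, a minimal sketch of what the comment describes, i.e. reading off the [CLS] position from BERT's output; it assumes the Hugging Face transformers API and the example checkpoint "bert-base-uncased", and the variable names are illustrative only:

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

inputs = tokenizer("An example sentence.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Position 0 of last_hidden_state is the [CLS] token's representation,
# commonly used as a sequence-level summary for classification heads.
cls_embedding = outputs.last_hidden_state[:, 0, :]  # shape: (1, 768)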

1 Answer


In bidirectional models, the hidden states of the two directions are concatenated at every time step, so each position in the output holds 512 units: the first 256 come from the forward pass and the last 256 from the backward pass. The line therefore takes the forward direction's final hidden state (position -1, units :256) and concatenates it with the backward direction's final hidden state (position 0, units 256:). Each of these is produced after its direction has read the entire sequence, so together they give the most complete summary of the input.
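A minimal sketch of the shapes involved (not the OP's actual model; it assumes BERT-base's 768-dimensional output and the 256-unit bidirectional LSTM from the thread, with a random tensor standing in for real BERT output):

import torch
import torch.nn as nn

batch_size, seq_len, bert_hidden = 4, 32, 768

# Stand-in for BERT's last_hidden_state.
bert_output = torch.randn(batch_size, seq_len, bert_hidden)

lstm = nn.LSTM(input_size=bert_hidden, hidden_size=256,
               batch_first=True, bidirectional=True)

# lstm_output has shape (batch, seq_len, 512): at every position the
# forward and backward hidden states are concatenated along the last dim.
lstm_output, (h_n, c_n) = lstm(bert_output)

# Forward direction's final state: last position, first 256 units.
# Backward direction's final state: first position, last 256 units.
hidden = torch.cat((lstm_output[:, -1, :256], lstm_output[:, 0, 256:]), dim=-1)

# h_n holds exactly these two final states, one per direction.
assert torch.allclose(hidden, torch.cat((h_n[0], h_n[1]), dim=-1))

The same two vectors are also exposed directly through h_n, which is often the less error-prone way to read them off.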

I've written a longer, more detailed answer on how hidden states are constructed in PyTorch for recurrent modules.
