The GRU module in PyTorch returns two objects: the output features and the hidden states. I understand that for classification one uses the output features, but I'm not entirely sure which part of them. Specifically, in a typical encoder-decoder architecture that uses a GRU in the encoder part, one would typically pass only the last (time-wise, i.e., t = N, where N is the length of the input sequence) output to the decoder. Which part of the output tensor refers to this time-wise last output?
The GRU is created like so (note that it is bidirectional):
self.gru = nn.GRU(
    700,  # input_size
    700,  # hidden_size
    bidirectional=True,
    batch_first=True,
)
Given an embedding tensor of size 150x700 representing a piece of text (150 is the sequence length, 700 the embedding dimension), I use the GRU like so:
gru_out, gru_hidden = self.gru(embedding)
gru_out will be of shape 150x1400, where 150 is again the sequence length and 1400 is double the embedding dimension, because the GRU is bidirectional (in terms of PyTorch's documentation, hidden_size * num_directions).
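To make the setup reproducible, here is a minimal sketch with the same sizes as above (no batch dimension, which PyTorch's RNN modules accept as "unbatched" input; the random input is just a stand-in for my embedding):

```python
import torch
import torch.nn as nn

# Same configuration as in the question: bidirectional GRU,
# input_size = hidden_size = 700, sequence length 150.
gru = nn.GRU(700, 700, bidirectional=True, batch_first=True)
embedding = torch.randn(150, 700)  # (seq_len, input_size), unbatched

gru_out, gru_hidden = gru(embedding)

print(gru_out.shape)     # (150, 1400) -> seq_len, num_directions * hidden_size
print(gru_hidden.shape)  # (2, 700)    -> num_directions, hidden_size
```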
If I only want to access the time-wise last output, do I need to access it like so?
tmp = gru_out.view(150, 2, 700)
last_out_first_direction = tmp[149, 0, :]
last_out_second_direction = tmp[149, 1, :]
While this technically seems right and is similar to the answer posted here, it assumes the actual input sequence is always of length 150. In practice, shorter sequences are simply padded up to length 150, and what one is usually interested in is the output at the last actual (non-padding) token, which can therefore sit at a position < 150. What is a common way to access the output at the actual last time step (<= 150) instead of only the technically last step (always = 150)?
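To make concrete what I mean by "actual last step": a hypothetical sketch, assuming a `lengths` tensor holding the true (pre-padding) sequence lengths is available (the variable names here are my own, not from my real code), that gathers each sequence's last non-padded forward-direction output:

```python
import torch
import torch.nn as nn

batch, seq_len, hidden = 4, 150, 700
gru = nn.GRU(hidden, hidden, bidirectional=True, batch_first=True)

x = torch.randn(batch, seq_len, hidden)      # padded batch of embeddings
lengths = torch.tensor([150, 90, 10, 37])    # assumed true lengths before padding

out, _ = gru(x)                              # (batch, seq_len, 2 * hidden)
out = out.view(batch, seq_len, 2, hidden)    # separate the two directions

# Forward direction: pick time step lengths - 1 for each sequence.
idx = (lengths - 1).view(batch, 1, 1).expand(batch, 1, hidden)
last_forward = out[:, :, 0, :].gather(1, idx).squeeze(1)  # (batch, hidden)

print(last_forward.shape)  # (4, 700)
```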
Side question: Since the second direction passes information through the GRU in reverse compared to the first direction, is its output also stored in reverse order, so that I should actually access last_out_second_direction = tmp[0, 1, :] instead of tmp[149, 1, :]?
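One way I thought of to check this empirically (a small sketch with toy sizes): compare each direction's output against the final hidden state returned by the GRU. If the backward direction reads the sequence from the end, its final hidden state should match the output at time step 0 rather than at the last step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
gru = nn.GRU(8, 8, bidirectional=True, batch_first=True)
x = torch.randn(5, 8)        # unbatched: (seq_len, input_size)

out, hidden = gru(x)
out = out.view(5, 2, 8)      # (seq_len, num_directions, hidden_size)

# Forward direction: final hidden state equals the LAST time step's output.
print(torch.allclose(out[-1, 0], hidden[0]))
# Backward direction: final hidden state equals the FIRST time step's output.
print(torch.allclose(out[0, 1], hidden[1]))
```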