How does PyTorch's torch.nn.utils.rnn.pack_padded_sequence function work?

What actually happens under the hood that lets PyTorch skip the redundant computation on padding? Relatedly, when this function is used with RNNs, how does it stop the padded entries in each sequence from contributing to the gradients of the RNN weights?
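For concreteness, here is a minimal sketch of the observable behaviour I am asking about. The tensor values, sequence lengths, and RNN sizes are made up for illustration; this only shows what can be seen from the outside, not how it is implemented internally.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# Three sequences padded with 0 to length 4 (batch_first layout).
padded = torch.tensor([[1, 2, 3, 4],
                       [5, 6, 0, 0],
                       [7, 0, 0, 0]])
lengths = torch.tensor([4, 2, 1])  # true lengths, sorted in descending order

packed = pack_padded_sequence(padded, lengths, batch_first=True)
print(packed.data)         # valid entries only, grouped by time step
print(packed.batch_sizes)  # how many sequences are still active at each step

# Same idea with a float input that tracks gradients (illustrative sizes).
torch.manual_seed(0)
x = torch.randn(3, 4, 2, requires_grad=True)
rnn = torch.nn.RNN(input_size=2, hidden_size=3, batch_first=True)

packed_x = pack_padded_sequence(x, lengths, batch_first=True)
out, h_n = rnn(packed_x)   # out is itself a PackedSequence
out.data.sum().backward()

# Gradient w.r.t. the padded positions of the second and third sequences.
print(x.grad[1, 2:])
print(x.grad[2, 1:])
```

The packed `data` holds 7 entries instead of 12, and the gradient at the padded positions comes back as zero, so the padding never seems to enter the computation at all. What I would like to understand is how the RNN kernels consume `data` and `batch_sizes` to achieve this.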

Note: This is not a question asking about the behaviour of this function from the user level (i.e. the API) or the motivation to save this computation (as in this question).

Ideally some clarification on the implementation would be useful!

Anil

0 Answers