
In PyTorch, I train an RNN/GRU/LSTM network by starting Backpropagation (Through Time) with:

loss.backward()

When the sequence is long, I'd like to do Truncated Backpropagation Through Time instead of normal Backpropagation Through Time, where the whole sequence is used.

But I can't find any parameter or function in the PyTorch API to set up truncated BPTT. Did I miss it? Am I supposed to code it myself in PyTorch?

u2gilles
    Just use `h = h.detach()` at the point where you want to cut the backprop. See [`repackage_hidden()`](https://github.com/pytorch/examples/blob/master/word_language_model/main.py#L103) in the language modeling example. It effectively does the truncation. – dedObed Dec 24 '18 at 16:36
  • Thanks. In this code, which parameter controls how many timesteps to backpropagate through? For example, the sequence length (args.bptt) is 35 in the code; say I want BPTT over just the last 5 timesteps. Which parameter would be set to 5? – u2gilles Dec 25 '18 at 03:37

1 Answer


Here is an example:

for t in range(T):
    out = lstm(out)
    if T - t == k:
        # detach() returns a new tensor; reassign it so the graph is actually cut
        out = out.detach()
out.backward()

So in this example, k is the parameter that controls how many of the final timesteps you backpropagate through.
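To make the idea above concrete, here is a minimal runnable sketch using a real `nn.LSTM`. The sizes, variable names, and loss are illustrative, not from the thread; the point is only that reassigning the detached hidden state cuts the autograd graph so gradients flow through the last k timesteps only:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, k = 35, 5              # full sequence; backprop through the last k steps only
lstm = nn.LSTM(input_size=10, hidden_size=20)
inputs = torch.randn(seq_len, 1, 10)   # (time, batch, features)
hidden = None                          # nn.LSTM accepts hx=None for a zero initial state

for t in range(seq_len):
    out, hidden = lstm(inputs[t:t + 1], hidden)
    if seq_len - t == k:
        # Detach both the hidden and cell state: gradients stop flowing past here.
        hidden = tuple(h.detach() for h in hidden)

loss = out.sum()   # toy loss, just to have a scalar to backprop
loss.backward()
```

Only the recurrence from the last k timesteps contributes gradient paths; everything earlier is treated as a constant input.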

angerhang
  • Shouldn't you also do an `out.backward()` in the if clause as well? – Hossein Mar 25 '19 at 18:58
  • Yes if you want to backprop at every single valid timestep. The example shown is only doing the update on the very last valid timestep. – angerhang Mar 25 '19 at 22:01
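The "backprop at every valid timestep" variant the comments discuss is usually written as chunked TBPTT: detach the hidden state between chunks and call backward once per chunk. A hedged sketch of that pattern, with illustrative chunk size, model sizes, and loss:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
chunk = 5                              # illustrative truncation length
lstm = nn.LSTM(input_size=10, hidden_size=20)
opt = torch.optim.SGD(lstm.parameters(), lr=0.1)
data = torch.randn(35, 1, 10)          # (time, batch, features)
hidden = None

for start in range(0, data.size(0), chunk):
    if hidden is not None:
        # Carry the state forward, but cut the graph at the chunk boundary.
        hidden = tuple(h.detach() for h in hidden)
    out, hidden = lstm(data[start:start + chunk], hidden)
    loss = out.pow(2).mean()           # toy loss for the sketch
    opt.zero_grad()
    loss.backward()                    # backprop within this chunk only
    opt.step()
```

Each `backward()` call only traverses the current chunk, so memory stays bounded regardless of total sequence length, while the detached hidden state still propagates information forward.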