
I want to develop a GRU-based model for variable-length input data. My idea is to loop over the time steps in forward and break out of the loop once the whole sequence has been processed. Will this affect the PyTorch computation graph? Does it disturb the network's gradients and the learning?

For example:

def forward(self, x):
    # x: (seq_len, batch, input_size); self.rnn is e.g. an nn.GRUCell
    state = self.initial_state
    out = []
    for i in range(x.size(0)):
        state = self.rnn(x[i], state)  # one recurrent step
        out.append(state)
        if condition:  # stop once the (shorter) sequence is fully processed
            break
    return out, state

I searched but couldn't find any related information, so I don't know whether this approach is correct or not.

1 Answer


The way PyTorch autograd works is by keeping track of operations involving a tensor that has requires_grad=True. If an operation never occurs because the loop broke before it was executed, it will never be tracked and will have no effect on the gradient. Here's a simple example of how it works.
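A minimal sketch (not the answer's original example; the scalar parameter w and the break condition are made up for illustration): only the iterations that actually run contribute to the gradient.

import torch

w = torch.tensor(2.0, requires_grad=True)
total = torch.tensor(0.0)

for i in range(5):
    total = total + w * i   # tracked: involves a tensor with requires_grad=True
    if i == 1:
        break               # iterations 2..4 never run, so they are never tracked

total.backward()
print(w.grad)  # tensor(1.) -- only i = 0 and i = 1 contributed (0 + 1)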

Since you mentioned that you couldn't find any related information, I can refer you to a PyTorch tutorial that implements an RNN from "scratch".

  • Thank you very much. So, as you explained, if my network runs three times with 7, 10, and 3 loop iterations and applies the gradient each time, then each backpropagation pass covers 7, 10, and 3 steps, respectively. – Alireza AR May 19 '23 at 12:09
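For illustration, a self-contained sketch along those lines (the layer sizes, the GRUCell/Linear setup, and the loss are made up, not taken from the thread): each sequence has a different length, and each backward() pass only spans the steps that were actually executed.

import torch
import torch.nn as nn

cell = nn.GRUCell(input_size=4, hidden_size=8)
head = nn.Linear(8, 1)
opt = torch.optim.SGD(list(cell.parameters()) + list(head.parameters()), lr=0.01)

for length in (7, 10, 3):
    x = torch.randn(length, 4)                 # one variable-length sequence
    state = torch.zeros(1, 8)
    for i in range(length):
        state = cell(x[i].unsqueeze(0), state)  # one GRU step per element
    loss = head(state).pow(2).mean()
    loss.backward()                            # graph spans exactly `length` steps
    opt.step()
    opt.zero_grad()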