I am training built-in PyTorch RNN modules (e.g. torch.nn.LSTM) and would like to add fixed-per-minibatch dropout between time steps (Gal dropout, if I understand correctly).
Most simply, I could unroll the network and run the forward computation on a single batch, something like this:
```python
dropout = get_fixed_dropout()            # samples one dropout mask per minibatch

total_loss = 0
for sequence in batch:
    state = initial_state
    outputs = []
    for token in sequence:
        output, state = rnn(token, state)        # torch.nn.LSTM returns (output, state)
        output, state = dropout(output, state)   # same mask at every time step
        outputs.append(output)
    total_loss += criterion(outputs, sequence)

total_loss.backward()
optimiser.step()
```
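By a 'fixed' dropout I mean something along these lines: a mask sampled once per minibatch and then reused unchanged at every time step (FixedDropout and get_fixed_dropout are just names I made up for illustration; here it is shown acting on a single hidden tensor):

```python
import torch

class FixedDropout:
    """Dropout whose mask is sampled once per minibatch and reused at every time step."""
    def __init__(self, p=0.5):
        self.p = p
        self.mask = None

    def reset_mask(self, like):
        # Called once per minibatch: sample a single mask, scaled so the
        # expected value of the kept units is unchanged (inverted dropout).
        self.mask = torch.bernoulli(torch.full_like(like, 1 - self.p)) / (1 - self.p)

    def __call__(self, hidden):
        # The same mask is applied at every time step of the sequence.
        return hidden * self.mask
```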
However, since looping in Python is presumably slow, I would much rather take advantage of PyTorch's ability to process a whole batch of sequences in a single call (as in rnn(batch, state) for a regular forward pass).
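The regular batched forward pass I have in mind is the standard one-call use of nn.LSTM, something like this (all sizes here are arbitrary):

```python
import torch
import torch.nn as nn

rnn = nn.LSTM(input_size=100, hidden_size=256)

batch = torch.randn(35, 20, 100)         # (seq_len, batch, input_size)
h0 = torch.zeros(1, 20, 256)             # (num_layers, batch, hidden_size)
c0 = torch.zeros(1, 20, 256)

# One call processes every time step of every sequence in the batch.
outputs, (h_n, c_n) = rnn(batch, (h0, c0))
```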
What I would prefer is something that looks like this:
```python
rnn_with_dropout = drop_neurons(rnn)
outputs, states = rnn_with_dropout(batch, initial_state)
total_loss = criterion(outputs, batch)
total_loss.backward()
optimiser.step()
rnn = return_dropped(rnn_with_dropout)
```
(Note: PyTorch's RNNs do have a dropout parameter, but it applies dropout between stacked layers rather than between time steps, so it is not what I want.)
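For reference, that parameter is the one passed at construction time; per the docs it applies dropout to the outputs of each layer except the last, e.g.:

```python
import torch.nn as nn

# dropout=0.5 here acts between the two stacked layers,
# not between the time steps within a layer.
rnn = nn.LSTM(input_size=100, hidden_size=256, num_layers=2, dropout=0.5)
```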
My questions are:
Am I even understanding Gal dropout correctly?
Is an interface like the one I want available?
If not, would it be equivalent to simply implement drop_neurons(rnn) and return_dropped(rnn) functions that zero random rows in the RNN's weight matrices and bias vectors, and then restore their previous values after the update step? (This imposes the same dropout between layers as between time steps, i.e. it completely removes some neurons for the whole minibatch, and I'm not sure that doing this is 'correct'.)
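Concretely, by drop_neurons / return_dropped I mean a rough sketch like the following (these helpers don't exist as far as I know; the row-wise masking is just my literal description above, not a tested or necessarily correct implementation):

```python
import torch

def drop_neurons(rnn, p=0.5):
    """Zero a random subset of rows in every weight matrix / bias vector,
    remembering the original row values so they can be restored later."""
    rnn._dropped_rows = {}
    with torch.no_grad():
        for name, param in rnn.named_parameters():
            rows = torch.rand(param.shape[0], device=param.device) < p
            rnn._dropped_rows[name] = (rows, param[rows].clone())
            param[rows] = 0
    return rnn

def return_dropped(rnn):
    """After optimiser.step(), write the pre-dropout values back into the rows
    that were zeroed, while the surviving rows keep their updates."""
    with torch.no_grad():
        for name, param in rnn.named_parameters():
            rows, values = rnn._dropped_rows[name]
            param[rows] = values
    del rnn._dropped_rows
    return rnn
```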