I'm building a stateful LSTM used for language recognition. Being stateful, I can train the network with smaller files, and a new batch acts like the next sentence in a discussion. However, for the network to be trained properly I need to reset the hidden state of the LSTM between some batches.
I'm using a variable to store the hidden state of the LSTM for performance:
with tf.variable_scope('Hidden_state'):
    hidden_state = tf.get_variable("hidden_state",
                                   [self.num_layers, 2, self.batch_size, self.hidden_size],
                                   tf.float32,
                                   initializer=tf.constant_initializer(0.0),
                                   trainable=False)
    # Arrange it to a tuple of LSTMStateTuple as needed
    l = tf.unstack(hidden_state, axis=0)
    rnn_tuple_state = tuple([tf.contrib.rnn.LSTMStateTuple(l[idx][0], l[idx][1])
                             for idx in range(self.num_layers)])

# Build the RNN
with tf.name_scope('LSTM'):
    rnn_output, _ = tf.nn.dynamic_rnn(cell, rnn_inputs,
                                      sequence_length=input_seq_lengths,
                                      initial_state=rnn_tuple_state,
                                      time_major=True)
Now I'm confused about how to reset the hidden state. I've tried two solutions, but neither is working:
First solution
Reset the "hidden_state" variable with :
rnn_state_zero_op = hidden_state.assign(tf.zeros_like(hidden_state))
It doesn't work, and I think it's because the unstack and tuple construction are not "re-played" into the graph after running the rnn_state_zero_op operation.
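For reference, this is roughly how I run that reset op between files in my training loop (sess, all_files_batches, feed_dict and train_op below are placeholders for my actual code):

# Reset the stored state at the start of each new file, then train on
# its batches as usual (placeholder names, not my real training code).
for file_batches in all_files_batches:
    sess.run(rnn_state_zero_op)              # zero the hidden_state variable
    for feed_dict in file_batches:
        sess.run(train_op, feed_dict=feed_dict)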
Second solution
Following LSTMStateTuple vs cell.zero_state() for RNN in Tensorflow, I tried to reset the cell state with:
rnn_state_zero_op = cell.zero_state(self.batch_size, tf.float32)
It doesn't seem to work either.
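My guess is that cell.zero_state only builds a tuple of zero-filled tensors rather than an op that writes to my hidden_state variable, so running it changes nothing, but I'm not sure. Purely for illustration:

# As far as I can tell, this is just a tuple of LSTMStateTuple(c, h)
# zero tensors (one per layer); sess.run on it evaluates those zeros
# but never assigns anything to the hidden_state variable.
print(rnn_state_zero_op)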
Question
I have another solution in mind, but it's guessing at best: I'm not keeping the state returned by tf.nn.dynamic_rnn. I've thought about keeping it, but it comes back as a tuple and I can't find a way to build an op to reset that tuple.
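To make that idea concrete, this is roughly what I mean (a sketch only, not something I have working):

rnn_output, final_state = tf.nn.dynamic_rnn(cell, rnn_inputs,
                                            sequence_length=input_seq_lengths,
                                            initial_state=rnn_tuple_state,
                                            time_major=True)
# final_state is a tuple of LSTMStateTuple (one per layer); unlike the
# hidden_state variable above, I don't see how to build a single
# "reset to zeros" op against this tuple.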
At this point I have to admit that I don't quite understand the internal workings of TensorFlow, or whether it's even possible to do what I'm trying to do. Is there a proper way to do it?
Thanks!