I am trying to find the best way to pass the LSTM state between batches. I have searched everywhere but could not find a solution that works with the current implementation. Imagine I have something like:
import tensorflow as tf
from tensorflow.contrib import rnn

# x_hot is the (one-hot) input tensor of shape [batch, time, features]
cells = [rnn.LSTMCell(size) for size in [256, 256]]
cells = rnn.MultiRNNCell(cells, state_is_tuple=True)
init_state = cells.zero_state(tf.shape(x_hot)[0], dtype=tf.float32)
net, new_state = tf.nn.dynamic_rnn(cells, x_hot, initial_state=init_state, dtype=tf.float32)
Now I would like to pass new_state back in as the initial state of the next batch efficiently, i.e. without copying it out of the graph and re-feeding it to TensorFlow through feed_dict. To be more precise, every solution I have found uses sess.run to evaluate new_state and a feed_dict to pass it back into init_state. Is there any way to do this without the feed_dict bottleneck?
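For reference, this is the kind of training loop I am trying to avoid (just a sketch of the usual pattern; train_op, get_batch and num_batches stand in for my own training op and data pipeline):

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    state = sess.run(init_state)  # numpy copy of the zero state
    for _ in range(num_batches):
        batch = get_batch()
        # new_state is pulled out of the graph and fed back in on the next
        # iteration, which is exactly the round trip I want to avoid
        _, state = sess.run([train_op, new_state],
                            feed_dict={x_hot: batch, init_state: state})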
I think I should use tf.assign in some way, but the documentation on this is sparse and I could not work out how.
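What I have in mind is roughly the sketch below, where the state lives in non-trainable variables that are updated inside the graph (batch_size, sizes and state_update are my own names, and I am not at all sure this is the intended way to use tf.assign):

batch_size = 32  # assumed fixed, so the state variables have a static shape
sizes = [256, 256]
cells = rnn.MultiRNNCell([rnn.LSTMCell(s) for s in sizes], state_is_tuple=True)

# one non-trainable (c, h) pair per layer to hold the state between runs
state_vars = tuple(
    rnn.LSTMStateTuple(
        tf.Variable(tf.zeros([batch_size, s]), trainable=False),
        tf.Variable(tf.zeros([batch_size, s]), trainable=False))
    for s in sizes)

net, new_state = tf.nn.dynamic_rnn(cells, x_hot, initial_state=state_vars,
                                   dtype=tf.float32)

# copy the final state back into the variables so that the next sess.run
# starts where this one ended, with no feed_dict round trip
state_update = tf.group(*[
    tf.assign(var, new)
    for var_pair, new_pair in zip(state_vars, new_state)
    for var, new in zip(var_pair, new_pair)])

I would then run state_update together with my train op, but I do not know whether this is correct or whether there is a cleaner way.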
I want to thank in advance everybody who will answer.
Cheers,
Francesco Saverio
All the other answers that I found on Stack Overflow either work with older versions or use the feed_dict method to pass the new state. For instance:
1) TensorFlow: Remember LSTM state for next batch (stateful LSTM). This works by feeding the state through a placeholder with feed_dict, which is what I want to avoid
2) Tensorflow - LSTM state reuse within batch. This does not work with the state tuple
3) Saving LSTM RNN state between runs in Tensorflow. Same here