I have a TensorFlow multi-layer RNN cell that is built like this:
def MakeLSTMCell(self):
    """Assemble the stacked, dropout-wrapped LSTM cell.

    One ``LSTMCell`` is created for every entry of ``self.numUnits`` (the
    entry is passed as the cell's first argument), and each is wrapped in a
    ``DropoutWrapper`` that uses ``self.keep_prob`` as both the input and
    the output keep probability.

    Returns:
        A ``MultiRNNCell`` built from the wrapped cells, in the same order
        as ``self.numUnits``.
    """
    wrapped_layers = [
        tf.nn.rnn_cell.DropoutWrapper(
            tf.nn.rnn_cell.LSTMCell(units),
            input_keep_prob=self.keep_prob,
            output_keep_prob=self.keep_prob,
        )
        for units in self.numUnits
    ]
    return tf.nn.rnn_cell.MultiRNNCell(wrapped_layers)
def BuildGraph(self):
    """Construct the recurrent network's graph inside ``self.graph``.

    Everything is created under the variable scope ``self.scope``. The
    method sets the following attributes:

    * ``inputSeq``: float32 placeholder of shape
      ``[None, None, self.observationDim]`` named ``'input_seq'``.
    * ``batch_size`` / ``seqLength``: elements 0 and 1 of the dynamic shape
      of ``inputSeq``.
    * ``cell``: the cell returned by ``self.MakeLSTMCell()``.
    * ``zeroState``: ``cell.zero_state(batch_size, tf.float32)``.
    * ``cellState``: the same object as ``zeroState``.
    * ``outputs`` / ``outputState``: the pair returned by
      ``tf.nn.dynamic_rnn`` when run from ``cellState``.
    """
    # NOTE(review): the original snippet carried no indentation; placing the
    # state setup and dynamic_rnn call inside the "LSTM_layers" name scope is
    # inferred from statement order -- confirm against the real file.
    with self.graph.as_default(), tf.variable_scope(self.scope):
        placeholder_shape = [None, None, self.observationDim]
        self.inputSeq = tf.placeholder(
            tf.float32, placeholder_shape, name='input_seq'
        )
        self.batch_size = tf.shape(self.inputSeq)[0]
        self.seqLength = tf.shape(self.inputSeq)[1]
        self.cell = self.MakeLSTMCell()

        with tf.name_scope("LSTM_layers"):
            self.zeroState = self.cell.zero_state(self.batch_size, tf.float32)
            # The RNN always starts from the all-zero state built above.
            self.cellState = self.zeroState
            rnn_result = tf.nn.dynamic_rnn(
                self.cell,
                self.inputSeq,
                initial_state=self.cellState,
                swap_memory=True,
            )
            self.outputs, self.outputState = rnn_result
However, this `self.cellState` is not configurable: it is always the zero state. I would like to know how I can save the LSTM hidden state (keeping the same structure, so that I can feed it back to the RNN at any time) and reuse it later as `initial_state`.
I've tried the accepted answer to this question: "Tensorflow, best way to save state in RNNs?" However, that approach stores the state in a `tf.Variable`, and a dynamic batch size is not allowed when creating a `tf.Variable`.
Any help will be appreciated.