
I've got a TensorFlow multi-layer RNN cell like this:

def MakeLSTMCell(self):
    cells = []
    for n in self.numUnits:
        cell = tf.nn.rnn_cell.LSTMCell(n)
        dropout = tf.nn.rnn_cell.DropoutWrapper(cell,
                                                input_keep_prob=self.keep_prob,
                                                output_keep_prob=self.keep_prob)
        cells.append(dropout)
    stackedRNNCell = tf.nn.rnn_cell.MultiRNNCell(cells)
    return stackedRNNCell

def BuildGraph(self):
    """
    Build the Graph of the recurrent reinforcement neural network.
    """
    with self.graph.as_default():
        with tf.variable_scope(self.scope):
            self.inputSeq = tf.placeholder(tf.float32, [None, None, self.observationDim], name='input_seq')
            self.batch_size = tf.shape(self.inputSeq)[0]
            self.seqLength = tf.shape(self.inputSeq)[1]
            self.cell = self.MakeLSTMCell()

            with tf.name_scope("LSTM_layers"):
                self.zeroState = self.cell.zero_state(self.batch_size, tf.float32)
                self.cellState = self.zeroState

                self.outputs, self.outputState = tf.nn.dynamic_rnn(self.cell,
                                                         self.inputSeq,
                                                         initial_state=self.cellState,
                                                         swap_memory=True)

However, this `self.cellState` is not configurable from outside the graph. How can I save the LSTM hidden state (keeping the same structure, so that I can feed it back to the RNN at any time) and later reuse it as `initial_state`?

I've tried the accepted answer to this question: Tensorflow, best way to save state in RNNs? However, a dynamic batch size is not allowed when creating a tf.Variable.

Any help will be appreciated.

Kevin Fang
  • There are, I think, essentially two ways. One is saving the state in variables, which, as you said, does not support a variable batch size, so you would have to fix the batch size and pad incomplete batches. The other is retrieving the output state on `run` and feeding it back through a placeholder (or a placeholder with a default, if you prefer not having to give the initial state by hand); the disadvantage here is some overhead getting the state out of and into TensorFlow, although it should not be anything terrible. – jdehesa Oct 22 '18 at 09:39
  • @jdehesa I finally made it work using a placeholder. Indeed, transforming between a tuple of StateTuples and a numpy array is quite painful and makes the code obscure, but at least it works. – Kevin Fang Oct 22 '18 at 22:10
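The round trip described in the comments — flattening a tuple of `LSTMStateTuple`s so it can be fed back through placeholders, then rebuilding the nested structure — can be sketched in plain Python. This is only an illustration of the bookkeeping, not TensorFlow code: `namedtuple` stands in for `tf.nn.rnn_cell.LSTMStateTuple`, and all function names are made up.

```python
from collections import namedtuple

# Stand-in for tf.nn.rnn_cell.LSTMStateTuple (c = cell state, h = hidden state).
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])

def flatten_state(state):
    """Flatten a tuple of LSTMStateTuples (one per layer) into a flat list.

    In a real session this is what you would do with the evaluated
    output state before feeding each array back into a placeholder.
    """
    flat = []
    for layer in state:
        flat.append(layer.c)
        flat.append(layer.h)
    return flat

def unflatten_state(flat):
    """Rebuild the tuple-of-LSTMStateTuples structure from a flat list."""
    assert len(flat) % 2 == 0, "expected a (c, h) pair per layer"
    return tuple(LSTMStateTuple(c=flat[i], h=flat[i + 1])
                 for i in range(0, len(flat), 2))

# Two-layer state; in practice each entry would be a [batch, num_units] array.
state = (LSTMStateTuple(c=[1.0], h=[2.0]),
         LSTMStateTuple(c=[3.0], h=[4.0]))
assert unflatten_state(flatten_state(state)) == state
```

In the actual graph, the feed side of this would be one `tf.placeholder` per `c` and `h` tensor (shaped `[None, num_units]` to keep the batch size dynamic), reassembled into `LSTMStateTuple`s and passed as `initial_state` — which is what makes the code obscure, as noted above.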

0 Answers