
I have written an RNN language model using TensorFlow. The model is implemented as an RNN class. The graph structure is built in the constructor, while RNN.train and RNN.test methods run it.

I want to be able to reset the RNN state when I move to a new document in the training set, or when I want to run a validation set during training. I do this by managing the state inside the training loop, passing it into the graph via a feed dictionary.

In the constructor I define the RNN like so:

    cell = tf.nn.rnn_cell.LSTMCell(hidden_units)
    rnn_layers = tf.nn.rnn_cell.MultiRNNCell([cell] * layers)
    self.reset_state = rnn_layers.zero_state(batch_size, dtype=tf.float32)
    self.state = tf.placeholder(tf.float32, self.reset_state.get_shape(), "state")
    self.outputs, self.next_state = tf.nn.dynamic_rnn(rnn_layers, self.embedded_input, time_major=True,
                                                  initial_state=self.state)

The training loop looks like this

    for document in documents:
        # Start every document from the zero state.
        state = session.run(self.reset_state)
        for x, y in document:
            _, state = session.run([self.train_step, self.next_state],
                                   feed_dict={self.x: x, self.y: y, self.state: state})

x and y are batches of training data in a document. The idea is that I pass the latest state along after each batch, except when I start a new document, when I zero out the state by running self.reset_state.

This all works. Now I want to change my RNN to use the recommended state_is_tuple=True. However, I don't know how to pass the more complicated LSTM state object via a feed dictionary. Also I don't know what arguments to pass to the self.state = tf.placeholder(...) line in my constructor.

What is the correct strategy here? There still isn't much example code or documentation for dynamic_rnn available.


TensorFlow issues 2695 and 2838 appear relevant.

A blog post on WILDML addresses these issues but doesn't directly spell out the answer.

See also TensorFlow: Remember LSTM state for next batch (stateful LSTM).

W.P. McNeill
  • check out `rnn_cell._unpacked_state` and `rnn_cell._packed_state`. These are used in `rnn._dynamic_rnn_loop()` to pass the state as a list of argument tensors to the loop function. – JunkMechanic Aug 25 '16 at 02:30
  • I don't see the strings `_unpacked_state` and `_packed_state` in the latest TensorFlow source. Have these names changed? – W.P. McNeill Aug 28 '16 at 22:30
  • Hmm. Those have been removed. Instead, a new module `tf.python.util.nest` has been introduced with analogues `flatten` and `pack_sequence_as`. – JunkMechanic Aug 30 '16 at 06:01
  • Has anyone tried to update their code for TF1.0.1? The API has changed markedly. – Hephaestus Apr 18 '17 at 18:17

2 Answers


One problem with a TensorFlow placeholder is that, as far as I know, you can only feed it a Python list or a NumPy array. So you can't keep the state between runs as a tuple of LSTMStateTuple objects and feed that directly.

I solved this by saving the state in a single array like this:

    initial_state = np.zeros((num_layers, 2, batch_size, state_size))

Each LSTM layer has two state components, the cell state and the hidden state; that is where the "2" comes from. (This article is great: https://arxiv.org/pdf/1506.00019.pdf)

When building the graph you unpack and create the tuple state like this:

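    state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
    # Split along the layer axis (tf.unpack was renamed tf.unstack in TensorFlow 1.0).
    l = tf.unpack(state_placeholder, axis=0)
    # One LSTMStateTuple(c, h) per layer.
    rnn_tuple_state = tuple(
        [tf.nn.rnn_cell.LSTMStateTuple(l[idx][0], l[idx][1])
         for idx in range(num_layers)]
    )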

Then you get the new state the usual way

    cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
    cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)

    outputs, state = tf.nn.dynamic_rnn(cell, series_batch_input, initial_state=rnn_tuple_state)

It shouldn't be like this... perhaps they are working on a solution.
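
To carry the state from one batch to the next with this layout, the nested tuple that comes back from the session presumably has to be turned into a single array again before it is fed to state_placeholder. A minimal NumPy-only sketch of that round trip follows. It assumes the fetched state is a tuple of num_layers (c, h) pairs of equally shaped arrays. The LSTMStateTuple here is a stand-in namedtuple rather than the TensorFlow class, and pack_state and unpack_state are hypothetical helper names.

```python
from collections import namedtuple

import numpy as np

# Stand-in for tf.nn.rnn_cell.LSTMStateTuple so the sketch runs without TensorFlow.
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])


def pack_state(state):
    """Stack a tuple of (c, h) pairs into one array of shape
    (num_layers, 2, batch_size, state_size)."""
    return np.stack([np.stack([layer.c, layer.h]) for layer in state])


def unpack_state(packed):
    """Inverse of pack_state: split the array back into per-layer (c, h) pairs."""
    return tuple(LSTMStateTuple(c=layer[0], h=layer[1]) for layer in packed)


num_layers, batch_size, state_size = 3, 4, 5

# Pretend this nested tuple was fetched from the graph after one training step.
fetched = tuple(
    LSTMStateTuple(
        c=np.full((batch_size, state_size), float(i)),
        h=np.full((batch_size, state_size), float(i) + 0.5),
    )
    for i in range(num_layers)
)

packed = pack_state(fetched)      # array with the same layout as initial_state
restored = unpack_state(packed)   # back to the nested tuple form
print(packed.shape)
```

In a training loop, the packed array would be the value fed for state_placeholder on the next step, and a fresh np.zeros array of the same shape would be fed at the start of a new document.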

user1506145
  • If you only have one layer, does it become `state_placeholder = tf.placeholder(tf.float32, [2, batch_size, state_size])` and `initial_state = np.zeros((2, batch_size, state_size))`? – Lukeyb Sep 12 '17 at 04:18

A simple way to feed in an RNN state is to feed both components of the state tuple individually. For a single LSTM cell, the state is one (c, h) pair:

    # Constructing the graph
    self.state = rnn_cell.zero_state(...)
    self.output, self.next_state = tf.nn.dynamic_rnn(
        rnn_cell,
        self.input,
        initial_state=self.state)

    # Running with initial state
    output, state = sess.run([self.output, self.next_state], feed_dict={
        self.input: input
    })

    # Running with subsequent state:
    output, state = sess.run([self.output, self.next_state], feed_dict={
        self.input: input,
        self.state[0]: state[0],
        self.state[1]: state[1]
    })
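
With a MultiRNNCell the state is a tuple of per-layer tuples, so there is one such pair of feed entries per layer. One way to generalize is to flatten both the state structure and the fetched values, then zip them into a feed dictionary. The sketch below is plain Python under stated assumptions: flatten and state_feed are hypothetical helpers (TensorFlow has its own nest utility with a flatten function, as mentioned in the comments on the question), and strings stand in for tensors and arrays.

```python
def flatten(structure):
    """Return the leaves of a nested tuple/list structure in depth-first order."""
    if isinstance(structure, (tuple, list)):
        leaves = []
        for item in structure:
            leaves.extend(flatten(item))
        return leaves
    return [structure]


def state_feed(state_tensors, state_values):
    """Pair every state tensor with the matching fetched value."""
    keys = flatten(state_tensors)
    values = flatten(state_values)
    if len(keys) != len(values):
        raise ValueError("state structures do not match")
    return dict(zip(keys, values))


# Two layers, each with a (c, h) pair; strings stand in for tensors and arrays.
state_tensors = (("c0", "h0"), ("c1", "h1"))
state_values = (("c0_value", "h0_value"), ("c1_value", "h1_value"))

feed = state_feed(state_tensors, state_values)
print(feed)
```

The resulting dictionary would be merged with the entry for self.input before being passed as feed_dict.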
Casey Chu