1

What's the best way to save the LSTM state between runs in TensorFlow? For the prediction phase, I need to pass in data one timestep at a time, because the input at each timestep depends on the output of the previous timestep.

I used the suggestion from this post, "Tensorflow, best way to save state in RNNs?", and tested it by passing in the same input over and over again without running the optimizer. If I understand correctly, if the output changes each time then the state is being saved, but if the output stays the same then it isn't. The result was that the output changes after the first run but then stays the same.

Here's my code:

 # Load the training pieces through the project's data_generator helper.
 pieces = data_generator.load_pieces(5)

 batches = 100        # first placeholder dimension (examples fed per run)
 sizes = [126, 122]   # last dimension of x (sizes[0]) and of y_ (sizes[1])
 steps = 128          # second placeholder dimension (timesteps per example)
 layers = 2           # number of stacked LSTM layers

 x = tf.placeholder(tf.float32, shape=[batches, steps, sizes[0]])
 y_ = tf.placeholder(tf.float32, shape=[batches, steps, sizes[1]])

 # Randomly initialised projection from the LSTM output width to the target width.
 W = tf.Variable(tf.random_normal([sizes[0], sizes[1]]))
 b = tf.Variable(tf.random_normal([sizes[1]]))

 # Build a separate BasicLSTMCell for every layer. Repeating one cell object
 # with `[cell] * layers` is rejected by TensorFlow 1.1 and makes all layers
 # share the same weights from TensorFlow 1.2 onwards.
 cells = [tf.nn.rnn_cell.BasicLSTMCell(sizes[0], forget_bias=0.0)
          for _ in range(layers)]
 layer = cells[0]  # keeps the original name `layer` defined
 lstm = tf.nn.rnn_cell.MultiRNNCell(cells)

 # ~~~~~ code from linked post ~~~~~
 def get_state_variables(batch_size, cell):
     """Create non-trainable variables holding each layer's initial state.

     The zero state of `cell` is taken for `batch_size` examples and every
     (c, h) pair in it is wrapped in `tf.Variable`s so that its value can be
     overwritten later.

     Returns:
         A tuple with one `LSTMStateTuple` of variables per layer, suitable
         for passing to `dynamic_rnn` as `initial_state`.
     """
     zero_states = cell.zero_state(batch_size, tf.float32)
     return tuple(
         tf.nn.rnn_cell.LSTMStateTuple(
             tf.Variable(c_init, trainable=False),
             tf.Variable(h_init, trainable=False))
         for c_init, h_init in zero_states)

 # Per-layer state variables that persist between session runs.
 states = get_state_variables(batches, lstm)

 # Run the stacked LSTM over x, starting from the saved state variables.
 # new_states is the state returned by dynamic_rnn for this run.
 outputs, new_states = tf.nn.dynamic_rnn(lstm, x, initial_state=states, dtype=tf.float32)

 def get_state_update_op(state_variables, new_states):
     """Build an op that copies `new_states` into `state_variables`.

     Both arguments are iterated in lockstep, one entry per layer. For each
     layer, element 0 and element 1 of the new state are assigned to the
     matching state variables.

     Returns:
         The result of `tf.tuple` over all assign ops, so that a single fetch
         triggers every assignment. Its value is not meant to be used.
     """
     assign_ops = [
         layer_vars[idx].assign(layer_values[idx])
         for layer_vars, layer_values in zip(state_variables, new_states)
         for idx in (0, 1)
     ]
     return tf.tuple(assign_ops)

 update_op = get_state_update_op(states, new_states)
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

 # Flatten the RNN outputs to 2-D so the projection is a single matmul,
 # then restore the [-1, steps, sizes[1]] layout to match y_.
 output = tf.reshape(outputs, [-1, sizes[0]])
 y = tf.nn.sigmoid(tf.matmul(output, W) + b)
 y = tf.reshape(y, [-1, steps, sizes[1]])

 # Sum -y_ * log(y) over axes 1 and 2, then average over the remaining axis.
 # NOTE(review): tf.log(y) is applied without clipping -- confirm y cannot reach 0.
 cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), [1, 2]))
 # The optimizer is left disabled so that only the saved state can change
 # between runs.
 # train_step = tf.train.AdadeltaOptimizer().minimize(cross_entropy)

 sess = tf.InteractiveSession()
 sess.run(tf.global_variables_initializer())
 batch_x, batch_y = data_generator.get_batch(pieces)
 for i in range(500):
     # Feed the same batch every iteration; fetching update_op stores the
     # new LSTM state in the state variables for the next run.
     error, _ = sess.run([cross_entropy, update_op], feed_dict={x: batch_x, y_: batch_y})
     # Parenthesised form prints the same text on Python 2 and Python 3.
     print(str(i) + ': ' + str(error))

Here's the error over time:

  • 0: 419.861
  • 1: 419.756
  • 2: 419.756
  • 3: 419.756 ...
Community
  • 1
  • 1
  • 2
    Possible duplicate of [TensorFlow: Remember LSTM state for next batch (stateful LSTM)](http://stackoverflow.com/questions/38241410/tensorflow-remember-lstm-state-for-next-batch-stateful-lstm) – Florentin Hennecker Jan 29 '17 at 18:49
  • I recommend you [this answer](https://stackoverflow.com/questions/44703593/tensorflow-how-to-access-all-the-middle-states-of-an-rnn-not-just-the-last-sta) which i tried few days ago. It works well. – winwin Oct 27 '17 at 08:59

1 Answer

0

I recommend this answer, which I tried a few days ago. It works well.

By the way, there's a way to avoid setting state_is_tuple to False:

class CustomLSTMCell(tf.contrib.rnn.LSTMCell):
    """LSTM cell that emits its full state as the per-step output.

    Unrolling this cell therefore exposes the state at every timestep through
    the RNN's outputs, instead of only the state after the last step.
    """

    def __init__(self, *args, **kwargs):
        """Create the underlying LSTMCell and widen the reported output size.

        All arguments are forwarded unchanged to the parent constructor.
        """
        # state_is_tuple does not have to be forced to False here: the state
        # parts are joined in __call__ instead.
        super(CustomLSTMCell, self).__init__(*args, **kwargs)
        # Report the output size as the total size of the state.
        self._output_size = np.sum(self._state_size)

    def __call__(self, inputs, state):
        """Run one LSTM step and return (joined next state, next state).

        The regular cell output is discarded. The first returned value is the
        next state with its parts concatenated along axis 1.
        """
        _, next_state = super(
            CustomLSTMCell, self).__call__(inputs, state)
        # Concatenating along axis 1 keeps axis 0 intact, so this is not
        # limited to a batch of one (unlike reshaping to [1, -1]) and does not
        # require the state parts to have equal widths.
        return tf.concat(next_state, axis=1), next_state
winwin
  • 57
  • 6