
I'm building a stateful LSTM used for language recognition. Being stateful, I can train the network with smaller files, and a new batch will act like the next sentence in a conversation. However, for the network to be properly trained I need to reset the hidden state of the LSTM between some batches.

I'm using a variable to store the hidden state of the LSTM for performance:

    with tf.variable_scope('Hidden_state'):
        hidden_state = tf.get_variable("hidden_state", [self.num_layers, 2, self.batch_size, self.hidden_size],
                                       tf.float32, initializer=tf.constant_initializer(0.0), trainable=False)
        # Arrange it into a tuple of LSTMStateTuple, as expected by dynamic_rnn
        l = tf.unstack(hidden_state, axis=0)
        rnn_tuple_state = tuple([tf.contrib.rnn.LSTMStateTuple(l[idx][0], l[idx][1])
                                for idx in range(self.num_layers)])

    # Build the RNN
    with tf.name_scope('LSTM'):
        rnn_output, _ = tf.nn.dynamic_rnn(cell, rnn_inputs, sequence_length=input_seq_lengths,
                                          initial_state=rnn_tuple_state, time_major=True)

Now I'm confused about how to reset the hidden state. I've tried two solutions, but neither works:

First solution

Reset the "hidden_state" variable with :

    rnn_state_zero_op = hidden_state.assign(tf.zeros_like(hidden_state))

It doesn't work, and I think it's because the unstack and tuple construction are not "re-played" in the graph after running the rnn_state_zero_op operation.

Second solution

Following "LSTMStateTuple vs cell.zero_state() for RNN in Tensorflow", I tried to reset the cell state with:

    rnn_state_zero_op = cell.zero_state(self.batch_size, tf.float32)

It doesn't seem to work either.

Question

I have another solution in mind, but it's a guess at best: I'm not keeping the state returned by tf.nn.dynamic_rnn. I've thought of keeping it, but I get a tuple and I can't find a way to build an op to reset the tuple.

At this point I have to admit that I don't quite understand the internal workings of TensorFlow, or whether it's even possible to do what I'm trying to do. Is there a proper way to do it?

Thanks!

AMairesse

2 Answers


Thanks to this answer to another question, I was able to find a way to have complete control over whether (and when) the internal state of the RNN should be reset to zero.

First, you need to define some variables to store the state of the RNN; this way you will have control over it:

with tf.variable_scope('Hidden_state'):
    state_variables = []
    for state_c, state_h in cell.zero_state(self.batch_size, tf.float32):
        state_variables.append(tf.nn.rnn_cell.LSTMStateTuple(
            tf.Variable(state_c, trainable=False),
            tf.Variable(state_h, trainable=False)))
    # Return as a tuple, so that it can be fed to dynamic_rnn as an initial state
    rnn_tuple_state = tuple(state_variables)

Note that this version defines the variables used by the LSTM directly, which is much better than the version in my question because you don't have to unstack and build the tuple, which adds some ops to the graph that you cannot run explicitly.

Second, build the RNN and retrieve the final state:

# Build the RNN
with tf.name_scope('LSTM'):
    rnn_output, new_states = tf.nn.dynamic_rnn(cell, rnn_inputs,
                                               sequence_length=input_seq_lengths,
                                               initial_state=rnn_tuple_state,
                                               time_major=True)

So now you have the new internal state of the RNN. You can define two ops to manage it.

The first one will update the variables for the next batch, so that in the next batch the "initial_state" of the RNN will be fed with the final state of the previous batch:

# Define an op to keep the hidden state between batches
update_ops = []
for state_variable, new_state in zip(rnn_tuple_state, new_states):
    # Assign the new state to the state variables on this layer
    update_ops.extend([state_variable[0].assign(new_state[0]),
                       state_variable[1].assign(new_state[1])])
# Return a tuple in order to combine all update_ops into a single operation.
# The tuple's actual value should not be used.
rnn_keep_state_op = tf.tuple(update_ops)

You should run this op in your session anytime you run a batch and want to keep the internal state.

Beware: if you run batch 1 with this op, then batch 2 will start with batch 1's final state, but if you don't call it again when running batch 2, then batch 3 will also start with batch 1's final state. My advice is to run this op every time you run the RNN.
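For illustration, a minimal sketch of a training loop using this op, assuming rnn_inputs and input_seq_lengths are placeholders, and where train_op and batches are hypothetical names standing in for your own graph and data:

# Hypothetical sketch: run your own train_op together with rnn_keep_state_op
# so the final state of this batch becomes the initial state of the next one.
for batch in batches:
    sess.run([train_op, rnn_keep_state_op],
             feed_dict={rnn_inputs: batch.inputs,
                        input_seq_lengths: batch.lengths})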

The second op will be used to reset the internal state of the RNN to zeros:

# Define an op to reset the hidden state to zeros
update_ops = []
for state_variable in rnn_tuple_state:
    # Assign the new state to the state variables on this layer
    update_ops.extend([state_variable[0].assign(tf.zeros_like(state_variable[0])),
                       state_variable[1].assign(tf.zeros_like(state_variable[1]))])
# Return a tuple in order to combine all update_ops into a single operation.
# The tuple's actual value should not be used.
rnn_state_zero_op = tf.tuple(update_ops)

You can call this op whenever you want to reset the internal state.
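For example (a sketch, with sess being your tf.Session):

# Reset the internal state, e.g. before starting a new, unrelated sequence
sess.run(rnn_state_zero_op)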

AMairesse
  • Thank you. That helps a lot! Can you also address the issue of using variable batch sizes? – ARAT Jun 20 '18 at 17:10
  • Hi, I've never looked into this question. In my use case I only have a problem when I reach the end of the training set, which I've solved by padding my input. Here is my code: https://github.com/inikdom/rnn-speech/blob/dev/models/AcousticModel.py (see the padding in lines 142 to 153). Maybe it can help. – AMairesse Jun 21 '18 at 20:42
  • I am on the phone, so I will check the code later. I have a quick question, though. You are padding short sequences with zeroes, but I am talking about batch size. Let's assume you have 107 different sequences and your batch size is 10. So you will have 11 batches, but the batch size of the last batch (the 11th) will be 7. Do you also create 3 additional sequences for the last batch, consisting of only 0s? – ARAT Jun 21 '18 at 21:33
  • Yes, exactly: the padding I'm referring to is about creating additional sequences for the last batch. They are made of zeros, and the input_length is also zero, so they have no effect on the learning (see the sketch below). – AMairesse Jun 22 '18 at 20:38
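To illustrate the padding discussed in the comments above, here is a minimal NumPy sketch (pad_batch and its argument layout are hypothetical, assuming time-major inputs as in the question):

import numpy as np

def pad_batch(inputs, lengths, batch_size):
    # Hypothetical helper: inputs is [max_time, actual_batch, features] and
    # lengths holds one sequence length per example. Missing examples are
    # filled with zeros and get a length of 0 so they don't affect learning.
    missing = batch_size - inputs.shape[1]
    if missing > 0:
        pad = np.zeros((inputs.shape[0], missing, inputs.shape[2]),
                       dtype=inputs.dtype)
        inputs = np.concatenate([inputs, pad], axis=1)
        lengths = np.concatenate([lengths, np.zeros(missing, dtype=lengths.dtype)])
    return inputs, lengths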

A simplified version of AMairesse's post, for one LSTM layer:

zero_state = tf.zeros(shape=[1, units[-1]])
self.c_state = tf.Variable(zero_state, trainable=False)
self.h_state = tf.Variable(zero_state, trainable=False)
self.init_encoder = tf.nn.rnn_cell.LSTMStateTuple(self.c_state, self.h_state)

self.output_encoder, self.state_encoder = tf.nn.dynamic_rnn(cell_encoder, layer, initial_state=self.init_encoder)

# save or reset states
self.update_ops += [self.c_state.assign(self.state_encoder.c, use_locking=True)]
self.update_ops += [self.h_state.assign(self.state_encoder.h, use_locking=True)]

Or you can use a replacement for init_encoder to reset the states at step == 0 (you need to pass self.step_tf into session.run() as a placeholder):

self.step_tf = tf.placeholder_with_default(tf.constant(-1, dtype=tf.int64), shape=[], name="step")

self.init_encoder = tf.cond(tf.equal(self.step_tf, 0),
  true_fn=lambda: tf.nn.rnn_cell.LSTMStateTuple(zero_state, zero_state),
  false_fn=lambda: tf.nn.rnn_cell.LSTMStateTuple(self.c_state, self.h_state))
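A possible way to drive this from the training loop (batches and inputs_ph are hypothetical names for your own data and input placeholder):

# Hypothetical usage: feed the step so the state is zeroed only at step == 0.
for step, batch in enumerate(batches):
    sess.run([self.output_encoder] + self.update_ops,
             feed_dict={self.step_tf: step, inputs_ph: batch})
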
Max Tkachenko