
I am working on a TensorFlow NN that uses an LSTM to track a parameter (a time-series regression problem). A batch of training data contains batch_size consecutive observations. I would like to use the LSTM state as input to the next sample: within a batch, the state after the first observation should be fed in as the initial state for the second observation, and so on. Below I define the LSTM state as a tensor with batch_size rows, and I would like to reuse that state within a batch:

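import tensorflow as tf  # TensorFlow 0.x-era graph API, as used in this question
# Create the cell first; state_is_tuple=False gives one state tensor that fits in a single Variable
cell = tf.nn.rnn_cell.BasicLSTMCell(100, state_is_tuple=False)
state = tf.Variable(cell.zero_state(batch_size, tf.float32), trainable=False)
# data: a list of [batch_size, input_size] tensors (tf.nn.rnn was later renamed static_rnn)
output, curr_state = tf.nn.rnn(cell, data, initial_state=state)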

The API has a tf.nn.state_saving_rnn, but its documentation is rather vague. My question: how do I reuse curr_state within a training batch?

Leeor
  • To clarify, you'd like to thread the state from the result of the first batch element into the start state for the next batch element, and so on? Isn't the batch dimension exactly a time dimension in that case? – Allen Lavoie Feb 09 '17 at 17:52
  • @Allen Lavoie, yes that's correct. Each data observation within the batch is a (multi-dimensional) time series window. The batch contains overlapping windows arranged sequentially. The batch dimension is a time dimension, with overlap and stride. – Leeor Feb 09 '17 at 19:03
  • In that case your batch dimension is really 1. Unless you have multiple sequences you can batch together, this is going to be relatively slow. There is an effort in progress to support approximations which allow batching for single longer time series, but nothing has been released publicly yet. – Allen Lavoie Feb 09 '17 at 21:12
  • Thank you for the explanation! If you could explain a bit more about how 'batching for single longer time series' will work and write it up as an answer, I'll accept it. – Leeor Feb 10 '17 at 07:30

1 Answer


You are basically there; you just need to update state with curr_state:

state_update = tf.assign(state, curr_state)

Then make sure you run either state_update itself or an operation that depends on state_update; otherwise the assignment never actually happens. For example:

with tf.control_dependencies([state_update]):
    # only ops created inside this block are forced to run after state_update
    model_output = ...
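To see why carrying the state across runs is sound, here is a framework-free NumPy sketch. It is not TensorFlow code and does not reproduce BasicLSTMCell's exact parameterization; lstm_step, run_lstm and all sizes are hypothetical. It treats one long series as a batch of size 1, processes it in chunks, and passes each chunk's final (h, c) into the next chunk, which is the role the non-trainable variable plus tf.assign(state, curr_state) plays between session.run calls.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def lstm_step(x, h, c, W, U, b):
    """One step of a standard LSTM cell (gate order here: input, forget, candidate, output).
    x: [batch, n_in]; h, c: [batch, n_units]."""
    z = x @ W + h @ U + b                      # [batch, 4 * n_units]
    i, f, g, o = np.split(z, 4, axis=1)
    c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    h = sigmoid(o) * np.tanh(c)
    return h, c


def run_lstm(xs, h, c, params):
    """Run the cell over xs of shape [batch, time, n_in] starting from (h, c).
    Returns outputs [batch, time, n_units] and the final (h, c)."""
    outputs = []
    for t in range(xs.shape[1]):
        h, c = lstm_step(xs[:, t, :], h, c, *params)
        outputs.append(h)
    return np.stack(outputs, axis=1), (h, c)


rng = np.random.default_rng(0)
n_in, n_units, total_len, chunk_len = 3, 5, 12, 4
params = (rng.normal(size=(n_in, 4 * n_units)),
          rng.normal(size=(n_units, 4 * n_units)),
          np.zeros(4 * n_units))

# One long series, so the batch dimension is 1: shape [1, time, n_in].
series = rng.normal(size=(1, total_len, n_in))
zeros = np.zeros((1, n_units))

# Reference: process the whole series in a single call.
full_out, _ = run_lstm(series, zeros, zeros, params)

# Chunk by chunk, carrying the final state of each chunk into the next one.
h, c = zeros, zeros
chunk_outs = []
for start in range(0, total_len, chunk_len):
    out, (h, c) = run_lstm(series[:, start:start + chunk_len, :], h, c, params)
    chunk_outs.append(out)
chunked_out = np.concatenate(chunk_outs, axis=1)

# Resetting the state to zeros at every chunk instead.
reset_out = np.concatenate(
    [run_lstm(series[:, s:s + chunk_len, :], zeros, zeros, params)[0]
     for s in range(0, total_len, chunk_len)], axis=1)

print(np.allclose(full_out, chunked_out))  # True
print(np.allclose(full_out, reset_out))    # False
```

Carrying the state reproduces the single-pass result exactly, while resetting it only matches for the first chunk. This equivalence holds for the forward pass; in TensorFlow, gradients do not flow back through the assigned variable, so backpropagation is truncated at chunk boundaries.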

As suggested in the comments, the typical setup for RNNs is a batch whose first dimension (0) is the number of sequences and whose second dimension (1) is the maximum length of each sequence (if you pass time_major=True to tf.nn.dynamic_rnn, these two are swapped). Ideally, to get good performance, you stack multiple sequences into one batch and then split that batch along the time axis into consecutive chunks. But that's really a different topic.
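The stack-then-split layout can be sketched in NumPy as follows, assuming equal-length sequences; time_chunks and the data are hypothetical. Each chunk would be fed in its own session.run, with the final state of one chunk carried into the next. Because row k of every chunk continues sequence k, row k of the carried state stays aligned with the right sequence.

```python
import numpy as np


def time_chunks(batch, chunk_len, time_major=False):
    """Split a batch of shape [num_sequences, max_len, n_features] into
    consecutive chunks along the time axis (axis 1). The last chunk is shorter
    if max_len is not a multiple of chunk_len. With time_major=True each chunk
    is transposed to [chunk_len, num_sequences, n_features]."""
    max_len = batch.shape[1]
    for start in range(0, max_len, chunk_len):
        chunk = batch[:, start:start + chunk_len, :]
        if time_major:
            chunk = np.transpose(chunk, (1, 0, 2))
        yield chunk


# Hypothetical data: 4 independent series, 10 time steps, 2 features each.
series = [np.arange(10 * 2, dtype=float).reshape(10, 2) + 100 * k for k in range(4)]
batch = np.stack(series, axis=0)  # [4, 10, 2]

chunks = list(time_chunks(batch, chunk_len=5))
print([c.shape for c in chunks])  # [(4, 5, 2), (4, 5, 2)]

tm = list(time_chunks(batch, chunk_len=5, time_major=True))
print([c.shape for c in tm])      # [(5, 4, 2), (5, 4, 2)]
```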

jdehesa