I am working on a TensorFlow neural network that uses an LSTM to track a parameter (a time-series regression problem). Each batch of training data contains batch_size consecutive observations. I would like the LSTM state produced by one observation to be the input state for the next one: within a batch, the state after the first observation should be fed into the second observation, and so on. Below I define the LSTM state as a tensor whose leading dimension is batch_size, and I would like to reuse that state within a batch:
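To make the desired behaviour concrete, here is a framework-free sketch of "the state after observation t becomes the initial state of observation t + 1". It assumes a hypothetical 1-unit LSTM with made-up toy weights (`PARAMS`) and plain Python floats instead of tensors; it is an illustration of the recurrence, not TensorFlow code.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

# Hypothetical toy weights for a 1-unit LSTM: (input weight, recurrent weight, bias) per gate.
PARAMS = {"i": (0.5, 0.4, 0.0), "f": (0.3, 0.6, 1.0), "o": (0.7, 0.2, 0.0), "g": (0.9, 0.5, 0.0)}

def lstm_step(x, state, p=PARAMS):
    """One LSTM step: returns (output h, new state (h, c))."""
    h, c = state
    pre = {k: w * x + u * h + b for k, (w, u, b) in p.items()}
    i, f, o = sigmoid(pre["i"]), sigmoid(pre["f"]), sigmoid(pre["o"])
    g = math.tanh(pre["g"])
    c_new = f * c + i * g          # cell state update
    h_new = o * math.tanh(c_new)   # hidden state, also the output
    return h_new, (h_new, c_new)

def run_consecutive(observations):
    """Feed the state after observation t into observation t + 1."""
    state = (0.0, 0.0)             # zero state only before the first observation
    outputs = []
    for x in observations:
        y, state = lstm_step(x, state)
        outputs.append(y)
    return outputs, state

def run_independent(observations):
    """Baseline: every observation starts from the zero state."""
    return [lstm_step(x, (0.0, 0.0))[0] for x in observations]

obs = [0.1, 0.4, -0.2, 0.3]
carried, final_state = run_consecutive(obs)
independent = run_independent(obs)
```

Only the first output coincides with the zero-state baseline; every later output depends on the state carried over from the previous observation.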
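import tensorflow as tf
cell = tf.nn.rnn_cell.BasicLSTMCell(100, state_is_tuple=False)  # single [batch_size, 200] state tensor
# Non-trainable variable holding the LSTM state, initialised to zeros
state = tf.Variable(cell.zero_state(batch_size, tf.float32), trainable=False)
# data: Python list of [batch_size, input_size] tensors, one per time step
output, curr_state = tf.nn.rnn(cell, data, initial_state=state)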
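In an unrolled RNN such as the one `tf.nn.rnn` builds, the state is passed from one element of the input list (one time step) to the next, while the batch_size rows inside each element are treated as independent sequences, each with its own state row. The pure-Python sketch below mimics that data flow; it assumes a plain tanh RNN cell with hypothetical weights `W_X` and `W_H` as a stand-in for the LSTM, and compares placing consecutive observations along the batch axis with placing them along the time axis.

```python
import math

W_X, W_H = 0.8, 0.5   # hypothetical weights of a 1-unit tanh RNN cell (stand-in for the LSTM)

def cell(x, h):
    return math.tanh(W_X * x + W_H * h)

def unrolled_rnn(inputs, initial_state):
    """inputs: list over time steps, each a list over batch rows.
    State flows along the list (time); batch rows never exchange state."""
    state = list(initial_state)
    outputs = []
    for step in inputs:                                      # time axis
        state = [cell(x, h) for x, h in zip(step, state)]    # each row uses its own state
        outputs.append(list(state))
    return outputs, state

obs = [0.1, 0.4, -0.2, 0.3]                   # consecutive observations

# Layout A: observations along the batch axis (1 time step, batch_size = 4).
out_batch, _ = unrolled_rnn([obs], [0.0] * len(obs))

# Layout B: observations along the time axis (4 time steps, batch_size = 1).
out_time, _ = unrolled_rnn([[x] for x in obs], [0.0])
```

In layout A every observation starts from the initial state; only in layout B does the state after observation t reach observation t + 1.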
The API includes tf.nn.state_saving_rnn, but its documentation is rather vague. My question: how can I reuse curr_state within a training batch?
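For the forward pass, saving the final state and restoring it on the next run means that a long sequence processed in consecutive chunks produces the same outputs as processing it in one go. The sketch below shows that property in plain Python; `StateSaver` is a hypothetical helper playing the role of a non-trainable state variable, the tanh cell and its weights are assumptions, and gradients and training are not modelled at all.

```python
import math

W_X, W_H = 0.8, 0.5   # hypothetical weights of a 1-unit tanh RNN cell

def run_chunk(chunk, h):
    """Run the cell over one chunk starting from state h; return outputs and final state."""
    outputs = []
    for x in chunk:
        h = math.tanh(W_X * x + W_H * h)
        outputs.append(h)
    return outputs, h

class StateSaver:
    """Hypothetical helper: keeps the last state between calls."""
    def __init__(self):
        self.saved = 0.0                      # zero state before the first chunk

    def run(self, chunk):
        # Start from the saved state, then overwrite it with this chunk's final state.
        outputs, self.saved = run_chunk(chunk, self.saved)
        return outputs

    def reset(self):
        self.saved = 0.0                      # e.g. at the start of a new, unrelated sequence

sequence = [0.1, 0.4, -0.2, 0.3, 0.05, -0.1]
full, _ = run_chunk(sequence, 0.0)            # whole sequence in one pass

saver = StateSaver()
chunked = []
for start in range(0, len(sequence), 2):     # three consecutive "batches" of 2 observations
    chunked.extend(saver.run(sequence[start:start + 2]))
```

Restarting every chunk from the zero state instead breaks the chain at each chunk boundary, so the outputs after the first chunk no longer match the single-pass result.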