My understanding is that tf.nn.dynamic_rnn returns the output of an RNN cell (e.g. an LSTM) at each time step, as well as the final state. How can I access the cell states at all time steps, not just the last one? For example, I want to average all the hidden states and then use the result in a subsequent layer.
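Stated as a formula, one plausible reading of "average all the hidden states" is a per-sequence mean over the valid (unpadded) steps only. Here $h_{b,t}$ is the hidden state of sequence $b$ at step $t$, and $L_b$ is that sequence's true length (the entries of X_lengths in the code). Both symbols are notation introduced for this sketch, not part of the original question:

$$\bar{h}_b = \frac{1}{L_b} \sum_{t=1}^{L_b} h_{b,t}$$

Dividing by the padded length (10 here) instead of $L_b$ would shrink the average of the shorter sequence toward zero, which is why the length normalization matters.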
The following is how I define an LSTM cell and then unroll it using tf.nn.dynamic_rnn. However, this only gives me the last cell state of the LSTM.
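import tensorflow as tf  # TensorFlow 1.x (uses tf.contrib and sessions)
import numpy as np

# [batch-size, sequence-length, dimensions]
X = np.random.randn(2, 10, 8)
# The second sequence has only 6 valid steps; pad the rest with zeros.
X[1, 6:] = 0
X_lengths = [10, 6]

cell = tf.contrib.rnn.LSTMCell(num_units=64, state_is_tuple=True)

# outputs: [batch, max_time, num_units]
# last_state: LSTMStateTuple(c, h) for the final valid step of each sequence
outputs, last_state = tf.nn.dynamic_rnn(
    cell=cell,
    dtype=tf.float64,
    sequence_length=X_lengths,
    inputs=X)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
out, last = sess.run([outputs, last_state], feed_dict=None)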
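For an LSTMCell without a projection, the `outputs` tensor from `tf.nn.dynamic_rnn` holds the hidden state h at every time step (shape `[batch, max_time, num_units]`), and steps past `sequence_length` are zero-filled. So the hidden-state average can plausibly be computed from `out` directly. The sketch below is NumPy-only and assumes `out` is the array returned by `sess.run`. Random data stands in for it here, and `masked_time_mean` is a hypothetical helper name, not a TensorFlow API. Note that this covers the hidden states h; the cell states c at intermediate steps are not part of `outputs`.

```python
import numpy as np


def masked_time_mean(outputs, lengths):
    """Hypothetical helper: average outputs[b, :lengths[b], :] over time for each b."""
    outputs = np.asarray(outputs, dtype=np.float64)
    lengths = np.asarray(lengths)
    max_time = outputs.shape[1]
    # mask[b, t] is True only for valid steps t < lengths[b]
    mask = np.arange(max_time)[None, :] < lengths[:, None]
    # Zero out padded steps explicitly, then sum over the time axis
    summed = (outputs * mask[:, :, None]).sum(axis=1)
    # Divide by each sequence's true length, not by max_time
    return summed / lengths[:, None]


# Stand-in for `out` from sess.run: [batch=2, max_time=10, num_units=64]
rng = np.random.default_rng(0)
out = rng.standard_normal((2, 10, 64))
out[1, 6:] = 0.0  # mimic dynamic_rnn zero-filling past sequence_length
X_lengths = [10, 6]

avg = masked_time_mean(out, X_lengths)
print(avg.shape)  # (2, 64)
```

The same arithmetic could also be kept inside the graph (for feeding a subsequent layer) using `tf.sequence_mask` and `tf.reduce_sum` on `outputs`. The NumPy version is shown only because it can be checked without a session.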