
My understanding is that tf.nn.dynamic_rnn returns the output of an RNN cell (e.g. LSTM) at each time step as well as the final state. How can I access the cell states at all time steps, not just the last one? For example, I want to be able to average all the hidden states and then use the result in a subsequent layer.

The following is how I define an LSTM cell and then unroll it using tf.nn.dynamic_rnn. But this only gives the last cell state of the LSTM.

import tensorflow as tf
import numpy as np

# [batch-size, sequence-length, dimensions] 
X = np.random.randn(2, 10, 8)
X[1,6:] = 0
X_lengths = [10, 6]

cell = tf.contrib.rnn.LSTMCell(num_units=64, state_is_tuple=True)

outputs, last_state = tf.nn.dynamic_rnn(
    cell=cell,
    dtype=tf.float64,
    sequence_length=X_lengths,
    inputs=X)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())                                 
out, last = sess.run([outputs, last_state], feed_dict=None)
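
For concreteness (assuming the TF 1.x API and num_units=64 as above), out holds the hidden state h for every time step, while last only holds the final step's state:

print(out.shape)                   # (2, 10, 64): h at every time step
print(last.c.shape, last.h.shape)  # (2, 64) each: c and h of the final step only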
CentAu
  • There is no reason to need access to internal states that are not part of the output. If this is your use case, I would look at defining an RNN identical to an LSTM but that outputs its full state. – jasekp Jun 22 '17 at 15:52
  • Have a look to this QA: https://stackoverflow.com/q/39716241/4282745 – pfm Jun 22 '17 at 16:29
  • Or to this https://github.com/tensorflow/tensorflow/issues/5731#issuecomment-262151359 – pfm Jun 22 '17 at 16:29

2 Answers


Something like this should work.

import tensorflow as tf
import numpy as np


class CustomRNN(tf.contrib.rnn.LSTMCell):
    def __init__(self, *args, **kwargs):
        kwargs['state_is_tuple'] = False  # force the use of a concatenated state
        super(CustomRNN, self).__init__(*args, **kwargs)  # create an LSTM cell
        self._output_size = self._state_size  # change the output size to the state size

    def __call__(self, inputs, state):
        output, next_state = super(CustomRNN, self).__call__(inputs, state)
        return next_state, next_state  # return two copies of the state, instead of the output and the state

X = np.random.randn(2, 10, 8)
X[1,6:] = 0
X_lengths = [10, 10]

cell = CustomRNN(num_units=64)

outputs, last_states = tf.nn.dynamic_rnn(
    cell=cell,
    dtype=tf.float64,
    sequence_length=X_lengths,
    inputs=X)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())                                 
states, last_state = sess.run([outputs, last_states], feed_dict=None)

This uses concatenated states, as I don't know if you can store an arbitrary number of tuple states. The states variable is of shape (batch_size, max_time_size, state_size).
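
If the c and h parts are needed separately for each step, the concatenated state can be split along the last axis (a small sketch; it assumes num_units=64 as above, so state_size is 128 with c first and h second, matching LSTMCell's concatenation order):

c_all, h_all = np.split(states, 2, axis=-1)  # each of shape (batch_size, max_time_size, 64)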

jasekp
  • Can you elaborate a bit on how this CustomRNN code returns the intermediate states? I am trying to understand your code! – CentAu Jun 22 '17 at 17:04
  • An LSTM state is the combination of the output (m) and a hidden state (c). This code takes the output (m) and replaces it with the concatenated state (c + m). Disregarding batch size, the output is a list of [(c1 + m1), (c2 + m2), ... ] instead of [m1, m2, ...]. – jasekp Jun 22 '17 at 17:13
  • So, this replaces the actual output (m) with hidden state (c), right (`return next_state, next_state` instead of `return m, new_state`)? Where are you concatenating output and hidden state (`m + c`)? – CentAu Jun 22 '17 at 17:22
  • `new_state` is (c + m) for an LSTM, so returning `new_state, new_state` will replace the output m with (c + m). See [this line](https://github.com/tensorflow/tensorflow/blob/r1.2/tensorflow/python/ops/rnn_cell_impl.py#L586) in the implementation. – jasekp Jun 22 '17 at 17:25

I would point you to this thread (emphasis mine):

You can write a variant of the LSTMCell that returns both state tensors as part of the output, if you need both c and h state for each time step. If you just need the h state, that's the output of each time step.

As @jasekp wrote in his comment, the output is really the h part of the state. The dynamic_rnn method then just stacks all the h parts across time (see the docstring of _dynamic_rnn_loop in this file):

def _dynamic_rnn_loop(cell,
                      inputs,
                      initial_state,
                      parallel_iterations,
                      swap_memory,
                      sequence_length=None,
                      dtype=None):
  """Internal implementation of Dynamic RNN.
    [...]
    Returns:
    Tuple `(final_outputs, final_state)`.
    final_outputs:
      A `Tensor` of shape `[time, batch_size, cell.output_size]`.  If
      `cell.output_size` is a (possibly nested) tuple of ints or `TensorShape`
      objects, then this returns a (possibly nested) tuple of Tensors matching
      the corresponding shapes.
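
So if only the h states are needed, as in the averaging example from the question, the stacked outputs of tf.nn.dynamic_rnn can be used directly. A minimal sketch that masks out the padded steps with the sequence lengths before averaging (tensor names follow the question's code):

mask = tf.sequence_mask(X_lengths, maxlen=tf.shape(outputs)[1], dtype=outputs.dtype)
mask = tf.expand_dims(mask, axis=-1)                          # (batch_size, max_time, 1)
mean_h = tf.reduce_sum(outputs * mask, axis=1) / tf.reduce_sum(mask, axis=1)  # (batch_size, num_units)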
pfm
  • LSTMCell is just one cell and returns both the state and the output, if I'm not mistaken. I think the unrolling done by `tf.nn.dynamic_rnn` only returns the last step's state. So, I need to modify that? It's strange that there is not already a higher-level solution for this. – CentAu Jun 22 '17 at 16:44