We generally wrote the RNNs in Sonnet to work on a single-timestep basis, because for Reinforcement Learning you often need to run one timestep to pick an action, and without that action you can't get the next observation (and hence the next input timestep) from the environment. It's easy to unroll a single-timestep module over a sequence using tf.nn.dynamic_rnn (see below). We also have a wrapper, DeepRNN, which takes care of composing several RNN cores per timestep, which I believe is what you're looking for. This has the advantage that the DeepRNN object supports the start state methods required by dynamic_rnn, so it's API compatible with LSTM or any other single-timestep module.
What you want to do should be achievable like this:
import tensorflow as tf
import sonnet as snt

# Create a single-timestep RNN module by composing recurrent modules and
# non-recurrent ops.
model = snt.DeepRNN([
    snt.LSTM(100),
    tf.tanh,
    snt.LSTM(100),
    tf.tanh,
    snt.LSTM(100),
    tf.sigmoid,
], skip_connections=False)

batch_size = 2
sequence_length = 3
input_size = 4

single_timestep_input = tf.random_uniform([batch_size, input_size])
sequence_input = tf.random_uniform([batch_size, sequence_length, input_size])

# Run the module on a single timestep.
single_timestep_output, next_state = model(
    single_timestep_input, model.initial_state(batch_size=batch_size))

# Unroll the module over a full sequence.
sequence_output, final_state = tf.nn.dynamic_rnn(
    model, sequence_input, dtype=tf.float32)
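As a quick sanity check, here's a minimal sketch of evaluating those tensors in graph mode (standard TF 1.x session boilerplate, not part of the snippet above):

# Initialize variables and evaluate both outputs in a session.
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  single_out, seq_out = sess.run([single_timestep_output, sequence_output])
  print(single_out.shape)  # Expected: (2, 100) - batch_size x final LSTM size.
  print(seq_out.shape)     # Expected: (2, 3, 100) - with the time axis added.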
A few things to note. First, if you haven't already, please have a look at the RNN example in the repository, as it shows a full graph-mode training setup around a fairly similar model.
Second, if you do end up needing to implement a more complex module than DeepRNN allows for, it's important to thread the recurrent state in and out of the module. In your example you're creating the input state internally, and the l1_state and l2_state outputs are effectively discarded, so the model can't be trained properly. If DeepRNN weren't available, your model would look like this:
class LSTMNetwork(snt.RNNCore):  # Note we inherit from the RNN-specific subclass.

  def __init__(self, name="LSTMNetwork"):
    super(LSTMNetwork, self).__init__(name=name)
    # Sub-modules must be constructed inside this module's variable scope.
    with self._enter_variable_scope():
      self.l1 = snt.LSTM(100)
      self.l2 = snt.LSTM(100)
      self.out = snt.LSTM(10)

  def initial_state(self, batch_size):
    # The state of this module is the tuple of the inner modules' states.
    return (self.l1.initial_state(batch_size),
            self.l2.initial_state(batch_size),
            self.out.initial_state(batch_size))

  def _build(self, inputs, prev_state):
    # Separate the components of prev_state.
    l1_prev_state, l2_prev_state, out_prev_state = prev_state
    l1_out, l1_next_state = self.l1(inputs, l1_prev_state)
    l1_out = tf.tanh(l1_out)
    l2_out, l2_next_state = self.l2(l1_out, l2_prev_state)
    l2_out = tf.tanh(l2_out)
    output, out_next_state = self.out(l2_out, out_prev_state)
    # Output state of LSTMNetwork contains the output states of inner modules.
    full_output_state = (l1_next_state, l2_next_state, out_next_state)
    return tf.sigmoid(output), full_output_state
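To show the state threading in action, here's a hedged usage sketch that manually unrolls this core for a few steps (the names num_steps and inputs_sequence are just for this example; to pass the core to tf.nn.dynamic_rnn you'd also need the state_size and output_size properties that the tf.nn.rnn_cell.RNNCell interface expects):

# Manually unroll the core, threading the state through each call.
network = LSTMNetwork()
batch_size = 2
num_steps = 3
inputs_sequence = [tf.random_uniform([batch_size, 4]) for _ in range(num_steps)]

state = network.initial_state(batch_size)
outputs = []
for step_input in inputs_sequence:
  # Each call consumes the previous state and returns the next one.
  output, state = network(step_input, state)
  outputs.append(output)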
Finally, if you're using eager mode I would strongly encourage you to have a look at Sonnet 2 - it's a complete rewrite for TF 2 / Eager mode. It's not backwards compatible, but all the same kinds of module compositions are possible. Sonnet 1 was written primarily for Graph mode TF, and while it does work with Eager mode you'll probably encounter some things that aren't very convenient.
We worked closely with the TensorFlow team to make sure that TF 2 & Sonnet 2 work nicely together, so please have a look: https://github.com/deepmind/sonnet/tree/v2. Sonnet 2 should be considered alpha, and is being actively developed, so we don't have loads of examples yet, but more will be added in the near future.
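For reference, the same composition in Sonnet 2 might look roughly like this (an illustrative sketch based on the Sonnet 2 API at the link above, not a definitive example; note that skip connections are handled differently there, so this just composes the cores directly):

import tensorflow as tf  # TF 2
import sonnet as snt     # Sonnet 2

model = snt.DeepRNN([
    snt.LSTM(100),
    tf.tanh,
    snt.LSTM(100),
    tf.tanh,
    snt.LSTM(100),
    tf.sigmoid,
])

batch_size = 2
single_timestep_input = tf.random.uniform([batch_size, 4])

# In eager mode you just call the module, threading the state as before.
state = model.initial_state(batch_size)
output, state = model(single_timestep_input, state)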