As this question has multiple major parts, I've dedicated a Q&A to the core challenge: stateful backpropagation. This answer focuses on implementing the variable output step length.
Description:
- As validated in Case 5, we can take a bottom-up approach. First we feed the complete input to model_a (A); then we feed its outputs as input to model_b (B), but this time one step at a time.
- Note that we must chain B's output steps per A's input step, not between A's input steps; i.e., in your diagram, gradient is to flow between Out[0][1] and Out[0][0], but not between Out[2][0] and Out[0][1].
- For computing loss it won't matter whether we use a ragged or padded tensor; we must however use a padded tensor for writing to TensorArray.
- Loop logic in code below is general; specific attribute handling and hidden state passing, however, is hard-coded for simplicity, but can be rewritten for generality.
Code: at bottom.
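The padding requirement can be illustrated outside TensorFlow. The following is a minimal NumPy sketch, not the implementation used at the bottom: pad_to_longest is a hypothetical helper, and the shapes are assumed from the example in this answer (steps_at_t=[3, 4, 2], model_b output shape (2, 4)). It zero-pads each per-input block of B's outputs along the step axis so that all blocks can be stacked into one dense array, which is what a TensorArray with a fixed element_shape needs.

```python
import numpy as np

def pad_to_longest(blocks):
    """Zero-pad each (steps, batch, units) block along axis 0, then stack.

    Hypothetical helper for illustration only; it plays the role that
    tf.pad plays in the full code before writing to the TensorArray.
    """
    longest_step = max(b.shape[0] for b in blocks)
    padded = []
    for b in blocks:
        # pad only the step axis (axis 0), at its end
        pad_width = [(0, longest_step - b.shape[0])] + [(0, 0)] * (b.ndim - 1)
        padded.append(np.pad(b, pad_width))
    return np.stack(padded)

steps_at_t = [3, 4, 2]
# one block of B's outputs per input step of A; ones stand in for real outputs
blocks = [np.ones((steps, 2, 4)) for steps in steps_at_t]
outputs = pad_to_longest(blocks)
print(outputs.shape)  # (3, 4, 2, 4)
```

The resulting layout is (outer_steps, longest_step, batch, units), with zeros wherever an input step of A produced fewer than longest_step outputs from B.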
Example:
- Here we predefine the number of iterations for B per input from A, but we can implement any arbitrary stopping logic. For example, we can take a Dense layer's output from B as a hidden state and check if its L2-norm exceeds a threshold.
- Per above, if longest_step is unknown to us, we can simply set it, which is common for NLP & other tasks with a STOP token.
- Alternatively, we may write to separate TensorArrays at every A's input with dynamic_size=True; see "point of uncertainty" below.
- A valid concern is, how do we know gradients flow correctly? Note that we've validated them for both vertical and horizontal flow in the linked Q&A, but it didn't cover multiple output steps per input step, for multiple input steps. See below.
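The stopping rule mentioned in the list above (an L2-norm check on a hidden state, with longest_step as an upper bound) can be sketched as a plain loop. This is only an illustration in NumPy under stated assumptions: toy_step, run_until_stop, and the threshold value are hypothetical, toy_step stands in for a call to model_b, and the norm is taken over the whole batch rather than per sample.

```python
import numpy as np

def toy_step(x, state):
    # Hypothetical stand-in for model_b: the state accumulates the input,
    # so its norm grows by a fixed amount at every step.
    new_state = state + x
    return new_state, new_state

def run_until_stop(x, threshold, longest_step):
    """Iterate until the state's L2-norm exceeds `threshold`,
    or until `longest_step` iterations have run, whichever comes first."""
    state = np.zeros_like(x)
    outputs = []
    for _ in range(longest_step):
        out, state = toy_step(x, state)
        outputs.append(out)
        if np.linalg.norm(state) > threshold:
            break
    return np.stack(outputs)

x = np.full((2, 4), 0.5)  # one (batch, units) slice from A
outs = run_until_stop(x, threshold=4.0, longest_step=10)
print(outs.shape[0])  # number of steps actually taken
```

Outputs produced this way have a data-dependent number of steps, so they would still need padding to longest_step before being written to a fixed-shape TensorArray.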
Point of uncertainty: I'm not entirely sure whether gradients interact between e.g. Out[0][1] and Out[2][0]. I did, however, verify that gradients will not flow horizontally if we write to separate TensorArrays for B's outputs per A's inputs (case 2); reimplementing for cases 4 & 5, grads will differ for both models, including the lower one with a complete single horizontal pass.
Thus we must write to a unified TensorArray. For such, as there are no ops leading from e.g. IR[1] to Out[0][1], I can't see how TF would trace it as such - so it seems we're safe. Note, however, that in the example below, using steps_at_t=[1]*6 will make gradient flow horizontally in both models, as we're writing to a single TensorArray and passing hidden states.
The examined case is confounded, however, with B being stateful at all steps; lifting this requirement, we might not need to write to a unified TensorArray for all Out[0], Out[1], etc, but we must still test against something we know works, which is no longer as straightforward.
Example [code]:
import numpy as np
import tensorflow as tf
#%%# Make data & models, then fit ###########################################
x0 = y0 = tf.constant(np.random.randn(2, 3, 4))
msn = MultiStatefulNetwork(batch_shape=(2, 3, 4), steps_at_t=[3, 4, 2])
#%%#############################################
with tf.GradientTape(persistent=True) as tape:
    outputs = msn(x0)
    # shape: (3, 4, 2, 4), 0-padded
    # We can pad labels accordingly.
    # Note the (2, 4) model_b's output shape, which is a timestep slice;
    # model_b is a *slice model*. Careful in implementing various logics
    # which are and aren't intended to be stateful.
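Since the outputs are 0-padded, one way to keep the padded positions out of the loss is to mask them. Below is a minimal NumPy sketch of that idea; step_mask and masked_mse are hypothetical helpers, not part of the code in this answer, and they assume labels are padded to the same (outer_steps, longest_step, batch, units) shape as the outputs. The random arrays only stand in for real outputs and labels.

```python
import numpy as np

def step_mask(steps_at_t):
    """mask[t, s] is 1.0 where step s is a real (unpadded) output of outer step t."""
    longest_step = max(steps_at_t)
    steps = np.asarray(steps_at_t)[:, None]
    return (np.arange(longest_step)[None, :] < steps).astype(float)

def masked_mse(y_true, y_pred, steps_at_t):
    # broadcast the (outer_steps, longest_step) mask over batch and units
    mask = step_mask(steps_at_t)[:, :, None, None]
    sq_err = (y_true - y_pred) ** 2 * mask
    # count only the entries that correspond to real steps
    n_valid = mask.sum() * y_true.shape[2] * y_true.shape[3]
    return sq_err.sum() / n_valid

steps_at_t = [3, 4, 2]
rng = np.random.default_rng(0)
y_pred = rng.standard_normal((3, 4, 2, 4))  # stand-in for msn(x0)
y_true = rng.standard_normal((3, 4, 2, 4))  # stand-in for padded labels
loss = masked_mse(y_true, y_pred, steps_at_t)
```

With this masking, the result equals the mean squared error computed over the ragged (unpadded) entries only. If both labels and outputs are exactly zero at the padded positions, the summed error is the same without a mask, and only the averaging denominator differs.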
Methods:
Not the cleanest, nor most optimal code, but it works; room for improvement.
More importantly: I implemented this in Eager, and have no idea how it'll work in Graph, and making it work for both can be quite tricky. If needed, just run in Graph and compare all values as done in the "cases".
# ideally we won't `import tensorflow` at all; kept for code simplicity
import tensorflow as tf
from tensorflow.python.util import nest
from tensorflow.python.ops import array_ops, tensor_array_ops
from tensorflow.python.framework import ops
from tensorflow.keras.layers import Input, SimpleRNN, SimpleRNNCell
from tensorflow.keras.models import Model
#######################################################################
class MultiStatefulNetwork():
    def __init__(self, batch_shape=(2, 6, 4), steps_at_t=[]):
        self.batch_shape = batch_shape
        self.steps_at_t = steps_at_t

        self.batch_size = batch_shape[0]
        self.units = batch_shape[-1]
        self._build_models()

    def __call__(self, inputs):
        outputs = self._forward_pass_a(inputs)
        outputs = self._forward_pass_b(outputs)
        return outputs

    def _forward_pass_a(self, inputs):
        return self.model_a(inputs, training=True)

    def _forward_pass_b(self, inputs):
        return model_rnn_outer(self.model_b, inputs, self.steps_at_t)

    def _build_models(self):
        ipt = Input(batch_shape=self.batch_shape)
        out = SimpleRNN(self.units, return_sequences=True)(ipt)
        self.model_a = Model(ipt, out)

        ipt = Input(batch_shape=(self.batch_size, self.units))
        sipt = Input(batch_shape=(self.batch_size, self.units))
        out, state = SimpleRNNCell(4)(ipt, sipt)
        self.model_b = Model([ipt, sipt], [out, state])

        self.model_a.compile('sgd', 'mse')
        self.model_b.compile('sgd', 'mse')

def inner_pass(model, inputs, states):
    return model_rnn(model, inputs, states)

def model_rnn_outer(model, inputs, steps_at_t=[2, 2, 4, 3]):
    def outer_step_function(inputs, states):
        x, steps = inputs
        x = array_ops.expand_dims(x, 0)
        x = array_ops.tile(x, [steps, *[1] * (x.ndim - 1)])  # repeat steps times
        output, new_states = inner_pass(model, x, states)
        return output, new_states

    (outer_steps, steps_at_t, longest_step, outer_t, initial_states,
     output_ta, input_ta) = _process_args_outer(model, inputs, steps_at_t)

    def _outer_step(outer_t, output_ta_t, *states):
        current_input = [input_ta.read(outer_t), steps_at_t.read(outer_t)]
        output, new_states = outer_step_function(current_input, tuple(states))

        # pad if shorter than longest_step.
        # model_b may output twice, but longest in `steps_at_t` is 4; then we need
        # output.shape == (2, *model_b.output_shape) -> (4, *...)
        # checking directly on `output` is more reliable than from `steps_at_t`
        output = tf.cond(
            tf.math.less(output.shape[0], longest_step),
            lambda: tf.pad(output, [[0, longest_step - output.shape[0]],
                                    *[[0, 0]] * (output.ndim - 1)]),
            lambda: output)

        output_ta_t = output_ta_t.write(outer_t, output)
        return (outer_t + 1, output_ta_t) + tuple(new_states)

    final_outputs = tf.while_loop(
        body=_outer_step,
        loop_vars=(outer_t, output_ta) + initial_states,
        cond=lambda outer_t, *_: tf.math.less(outer_t, outer_steps))

    output_ta = final_outputs[1]
    outputs = output_ta.stack()
    return outputs

def _process_args_outer(model, inputs, steps_at_t):
    def swap_batch_timestep(input_t):
        # Swap the batch and timestep dim for the incoming tensor.
        # (samples, timesteps, channels) -> (timesteps, samples, channels)
        # iterating dim0 to feed (samples, channels) slices expected by RNN
        axes = list(range(len(input_t.shape)))
        axes[0], axes[1] = 1, 0
        return array_ops.transpose(input_t, axes)

    inputs = nest.map_structure(swap_batch_timestep, inputs)

    assert inputs.shape[0] == len(steps_at_t)
    outer_steps = array_ops.shape(inputs)[0]  # model_a_steps
    longest_step = max(steps_at_t)
    steps_at_t = tensor_array_ops.TensorArray(
        dtype=tf.int32, size=len(steps_at_t)).unstack(steps_at_t)

    # assume single-input network, excluding states which are handled separately
    input_ta = tensor_array_ops.TensorArray(
        dtype=inputs.dtype,
        size=outer_steps,
        element_shape=tf.TensorShape(model.input_shape[0]),
        tensor_array_name='outer_input_ta_0').unstack(inputs)

    # TensorArray is used to write outputs at every timestep, but does not
    # support RaggedTensor; thus we must make TensorArray such that column length
    # is that of the longest outer step, and pad model_b's outputs accordingly
    element_shape = tf.TensorShape((longest_step, *model.output_shape[0]))

    # overall shape: (outer_steps, longest_step, *model_b.output_shape)
    # for every input / at each step we write in dim0 (outer_steps)
    output_ta = tensor_array_ops.TensorArray(
        dtype=model.output[0].dtype,
        size=outer_steps,
        element_shape=element_shape,
        tensor_array_name='outer_output_ta_0')

    outer_t = tf.constant(0, dtype='int32')
    initial_states = (tf.zeros(model.input_shape[0], dtype='float32'),)

    return (outer_steps, steps_at_t, longest_step, outer_t, initial_states,
            output_ta, input_ta)

def model_rnn(model, inputs, states):
    def step_function(inputs, states):
        output, new_states = model([inputs, *states], training=True)
        return output, new_states

    initial_states = states
    input_ta, output_ta, time, time_steps_t = _process_args(model, inputs)

    def _step(time, output_ta_t, *states):
        current_input = input_ta.read(time)
        output, new_states = step_function(current_input, tuple(states))

        flat_state = nest.flatten(states)
        flat_new_state = nest.flatten(new_states)
        for state, new_state in zip(flat_state, flat_new_state):
            if isinstance(new_state, ops.Tensor):
                new_state.set_shape(state.shape)

        output_ta_t = output_ta_t.write(time, output)
        new_states = nest.pack_sequence_as(initial_states, flat_new_state)
        return (time + 1, output_ta_t) + tuple(new_states)

    final_outputs = tf.while_loop(
        body=_step,
        loop_vars=(time, output_ta) + tuple(initial_states),
        cond=lambda time, *_: tf.math.less(time, time_steps_t))

    new_states = final_outputs[2:]
    output_ta = final_outputs[1]
    outputs = output_ta.stack()
    return outputs, new_states

def _process_args(model, inputs):
    time_steps_t = tf.constant(inputs.shape[0], dtype='int32')

    # assume single-input network (excluding states)
    input_ta = tensor_array_ops.TensorArray(
        dtype=inputs.dtype,
        size=time_steps_t,
        tensor_array_name='input_ta_0').unstack(inputs)

    # assume single-output network (excluding states)
    output_ta = tensor_array_ops.TensorArray(
        dtype=model.output[0].dtype,
        size=time_steps_t,
        element_shape=tf.TensorShape(model.output_shape[0]),
        tensor_array_name='output_ta_0')

    time = tf.constant(0, dtype='int32', name='time')
    return input_ta, output_ta, time, time_steps_t