I am aware that there is a similar topic at LSTM Followed by Mean Pooling, but that is about Keras and I work in pure TensorFlow.
I have an LSTM network where the recurrence is handled by:
    outputs, final_state = tf.nn.dynamic_rnn(cell,
                                             embed,
                                             sequence_length=seq_lengths,
                                             initial_state=initial_state)
where I pass the correct sequence length for each sample (the inputs are padded with zeros). As a result, outputs contains irrelevant entries: every sample is padded to the same number of time steps, so the shorter samples have outputs beyond their actual sequence length that should be ignored.
Right now I'm extracting the last relevant output by means of the following method:
    import tensorflow as tf

    def extract_axis_1(data, ind):
        """
        Get specified elements along the first axis of tensor.
        :param data: TensorFlow tensor that will be subsetted.
        :param ind: Indices to take (one for each element along axis 0 of data).
        :return: Subsetted tensor.
        """
        # One row index per sample in the batch: [0, 1, ..., batch_size - 1].
        batch_range = tf.range(tf.shape(data)[0])
        # Pair each batch index with its time-step index: shape [batch, 2].
        indices = tf.stack([batch_range, ind], axis=1)
        # Pick data[b, ind[b]] for every sample b.
        res = tf.gather_nd(data, indices)
        return res
where I pass sequence_length - 1 as indices. Following the topic linked above, I would like to select all relevant outputs and apply average pooling over them, instead of taking just the last one.
Now, I tried passing nested lists as indices to extract_axis_1, but tf.stack does not accept this.
Any pointers toward a solution?
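One direction that might work, as a rough sketch rather than a tested solution for this network: instead of gathering a variable number of indices per sample (which gives ragged lists that tf.stack cannot combine), build a 0/1 mask from the sequence lengths with tf.sequence_mask, sum the masked outputs over time, and divide by each sample's length. The function name masked_mean_pool is made up for illustration, and the sketch assumes outputs has shape [batch, max_time, units] and the lengths are an integer tensor of shape [batch].

```python
import tensorflow as tf


def masked_mean_pool(outputs, seq_lengths):
    """Hypothetical helper: average outputs over the valid time steps only.

    outputs:     float tensor, assumed shape [batch, max_time, units]
    seq_lengths: int tensor, assumed shape [batch]
    returns:     float tensor of shape [batch, units]
    """
    outputs = tf.convert_to_tensor(outputs)
    max_time = tf.shape(outputs)[1]
    # mask[b, t] is 1.0 where t < seq_lengths[b] and 0.0 elsewhere.
    mask = tf.sequence_mask(seq_lengths, maxlen=max_time, dtype=outputs.dtype)
    # Zero out the padded steps, then sum over the time axis.
    summed = tf.reduce_sum(outputs * tf.expand_dims(mask, axis=-1), axis=1)
    # Divide by the true length; the maximum guards against zero-length samples.
    lengths = tf.maximum(tf.cast(seq_lengths, outputs.dtype), 1.0)
    return summed / tf.expand_dims(lengths, axis=-1)
```

With the variables from the question this would be called as masked_mean_pool(outputs, seq_lengths). Since tf.nn.dynamic_rnn zeroes the outputs past each sample's sequence_length, the mask is arguably redundant for the sum, but keeping it means the function does not depend on that behavior.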