I got it to work without embeddings using a very rudimentary `InferenceHelper`:
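```python
import tensorflow as tf  # TensorFlow 1.x (tf.contrib was removed in 2.x)

inference_helper = tf.contrib.seq2seq.InferenceHelper(
    sample_fn=lambda outputs: outputs,  # use the raw projected outputs as the samples
    sample_shape=[dim],
    sample_dtype=tf.float32,
    start_inputs=start_tokens,          # shape [batch_size, dim]
    end_fn=lambda sample_ids: False)    # never signal the end; rely on maximum_iterations
```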
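Conceptually, this configuration turns decoding into a plain feedback loop: each step's projected output is taken unchanged as the "sample" and fed back as the next input, and because `end_fn` never returns True, decoding only stops at `maximum_iterations`. The NumPy sketch below mimics that loop outside TensorFlow. `step_fn` and `toy_step` are hypothetical stand-ins for the decoder cell plus projection layer, and the tolerance-based `end_fn` in the usage is just one illustrative stop condition, not something `InferenceHelper` provides.

```python
import numpy as np

def run_inference_loop(step_fn, state, start_inputs, max_iterations,
                       end_fn=lambda samples: np.zeros(len(samples), dtype=bool)):
    """Free-running decoding in the spirit of InferenceHelper with an identity sample_fn.

    step_fn(inputs, state) -> (outputs, new_state) is a hypothetical stand-in for
    the decoder cell followed by the projection layer.
    """
    inputs = start_inputs                         # [batch_size, dim]
    finished = np.zeros(len(start_inputs), dtype=bool)
    outputs_per_step = []
    for _ in range(max_iterations):
        outputs, state = step_fn(inputs, state)
        samples = outputs                         # sample_fn=lambda outputs: outputs
        # Rows that finished on an earlier step emit zeros (like impute_finished=True).
        outputs_per_step.append(np.where(finished[:, None], 0.0, samples))
        finished |= end_fn(samples)
        if finished.all():
            break
        inputs = samples                          # the sample becomes the next input
    return np.stack(outputs_per_step, axis=1)     # [batch_size, time, dim]

# Hypothetical toy "cell": next output = 0.5 * input + 1, no real state.
def toy_step(inputs, state):
    return 0.5 * inputs + 1.0, state

start = np.zeros((2, 1))  # batch_size=2, dim=1, go token 0.0
out = run_inference_loop(toy_step, None, start, max_iterations=3)
print(out[0, :, 0])  # [1.   1.5  1.75]

# Illustrative end_fn: stop once a sample is (almost) equal to 1.5.
out_early = run_inference_loop(toy_step, None, start, max_iterations=5,
                               end_fn=lambda s: np.abs(s[:, 0] - 1.5) < 1e-6)
print(out_early.shape)  # (2, 2, 1)
```

With the always-False `end_fn`, the output length is fixed by `max_iterations`, which matches a setup where the data source itself decides how much to generate.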
My inputs are floats with shape `[batch_size, time, dim]`. In the example below `dim` is 1, but this extends easily to more dimensions. Here's the relevant part of the code:
```python
import tensorflow as tf  # TensorFlow 1.x

# dec_cell, enc_state, output_data, target_data, go_token, batch_size and
# output_size are defined elsewhere in the model (see the full code linked below).

projection_layer = tf.layers.Dense(
    units=1,  # = dim
    kernel_initializer=tf.truncated_normal_initializer(
        mean=0.0, stddev=0.1))

# Training Decoder
training_decoder_output = None
with tf.variable_scope("decode"):
    # output_data doesn't exist during the prediction phase.
    if output_data is not None:
        # Prepend the "go" token along the time axis.
        # float32 so it can be concatenated with the float targets.
        go_tokens = tf.constant(go_token, dtype=tf.float32, shape=[batch_size, 1, 1])
        dec_input = tf.concat([go_tokens, target_data], axis=1)

        # Helper for the training process (feeds the ground truth as inputs).
        training_helper = tf.contrib.seq2seq.TrainingHelper(
            inputs=dec_input,
            sequence_length=[output_size] * batch_size)

        # Basic decoder
        training_decoder = tf.contrib.seq2seq.BasicDecoder(
            dec_cell, training_helper, enc_state, projection_layer)

        # Perform dynamic decoding using the decoder
        training_decoder_output = tf.contrib.seq2seq.dynamic_decode(
            training_decoder, impute_finished=True,
            maximum_iterations=output_size)[0]

# Inference Decoder
# Reuses the same parameters trained by the training process.
with tf.variable_scope("decode", reuse=tf.AUTO_REUSE):
    start_tokens = tf.constant(
        go_token, dtype=tf.float32, shape=[batch_size, 1])

    # The sample_ids are the actual output in this case (no logits involved).
    # My end_fn is always False because I'm working with a generator that stops
    # giving more data. You may extend end_fn as you wish, e.g. append end tokens
    # and make end_fn return True when the sample_id is the end token.
    inference_helper = tf.contrib.seq2seq.InferenceHelper(
        sample_fn=lambda outputs: outputs,
        sample_shape=[1],  # again because dim=1
        sample_dtype=tf.float32,
        start_inputs=start_tokens,
        end_fn=lambda sample_ids: False)

    # Basic decoder
    inference_decoder = tf.contrib.seq2seq.BasicDecoder(
        dec_cell, inference_helper, enc_state, projection_layer)

    # Perform dynamic decoding using the decoder
    inference_decoder_output = tf.contrib.seq2seq.dynamic_decode(
        inference_decoder, impute_finished=True,
        maximum_iterations=output_size)[0]
```
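The training branch relies on teacher forcing: `dec_input` is the target sequence shifted right by one step, with the go token in front. Since `sequence_length` is `output_size`, the helper reads only the first `output_size` of the `output_size + 1` available steps, so the last target value is never fed as an input. A small NumPy sketch of that shift, assuming `target_data` has `output_size` time steps and a go token of 0.0 (both illustrative; `teacher_forcing_inputs` is a hypothetical helper, and it generalizes the go-token shape from `[batch_size, 1, 1]` to `[batch_size, 1, dim]`):

```python
import numpy as np

def teacher_forcing_inputs(target_data, go_token, output_size):
    """Build the decoder inputs that the TrainingHelper actually consumes.

    target_data: float array of shape [batch_size, time, dim].
    Returns an array of shape [batch_size, output_size, dim].
    """
    batch_size, _, dim = target_data.shape
    go_tokens = np.full((batch_size, 1, dim), go_token, dtype=target_data.dtype)
    dec_input = np.concatenate([go_tokens, target_data], axis=1)  # prepend along time
    # With sequence_length = output_size, only the first output_size steps are read.
    return dec_input[:, :output_size, :]

targets = np.array([[[1.0], [2.0], [3.0]]])  # batch_size=1, time=3, dim=1
inputs = teacher_forcing_inputs(targets, go_token=0.0, output_size=3)
print(inputs[0, :, 0])  # [0. 1. 2.]: go token, then the targets shifted right by one
```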
Have a look at this question. I also found this tutorial very useful for understanding seq2seq models, although it uses embeddings. To adapt it, replace their `GreedyEmbeddingHelper` with an `InferenceHelper` like the one I posted above.
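The difference between the two helpers comes down to how each step's output becomes the next input. `GreedyEmbeddingHelper` takes the argmax over the logits as the sample id and looks that id up in an embedding; the `InferenceHelper` above, with its identity `sample_fn`, feeds the projected output back unchanged. A NumPy sketch of just that per-step rule (the function names and the toy embedding are illustrative, not TensorFlow API):

```python
import numpy as np

def greedy_embedding_next_input(logits, embedding):
    """Per-step rule of GreedyEmbeddingHelper: argmax id -> embedding lookup."""
    sample_ids = np.argmax(logits, axis=-1)    # [batch_size]
    return sample_ids, embedding[sample_ids]   # [batch_size, embed_dim]

def identity_next_input(outputs):
    """Per-step rule of InferenceHelper(sample_fn=lambda outputs: outputs)."""
    return outputs, outputs                    # the sample *is* the next input

embedding = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])  # vocab of 3, embed_dim=2
ids, next_embedded = greedy_embedding_next_input(np.array([[0.1, 2.0, 0.3]]), embedding)
samples, next_raw = identity_next_input(np.array([[0.7]]))
print(ids, next_embedded, next_raw)  # [1] [[1. 1.]] [[0.7]]
```

When swapping helpers, note that the start value changes too: `GreedyEmbeddingHelper` takes integer start tokens, while `InferenceHelper` takes float `start_inputs` of shape `[batch_size, dim]`, like `start_tokens` above.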
P.S. I posted the full code at https://github.com/Andreea-G/tensorflow_examples