
In neural machine translation (NMT) with a seq2seq architecture, inference requires the embedding variable trained during the training phase as an input to the GreedyEmbeddingHelper or the BeamSearchDecoder.

The question is: within the context of training and inference using the Estimator API, how can we extract this trained embedding variable so it can be used for prediction?

cad86
    Does https://stackoverflow.com/questions/37660685/how-to-get-tensorflow-seq2seq-embedding-output help you? – bantmen Mar 26 '18 at 01:43
  • Not really. In the Estimator API implementation of a seq2seq, the output embeddings are usually trained under an IF clause that can only be accessed during training and evaluation, since in these two phases you already know the output. For the prediction you don't, so you can't access that bit. Thanks for the reference though. – cad86 Mar 26 '18 at 09:06

1 Answer


I figured out a solution based on the following Stack Overflow answer. For the prediction phase, you can use tf.contrib.framework.load_variable to retrieve the embedding variable from a trained and saved TensorFlow model as follows:

if mode == tf.estimator.ModeKeys.PREDICT:
    # Load the trained embedding matrix from the checkpoint directory ('.')
    embeddings = tf.constant(
        tf.contrib.framework.load_variable('.', 'embed/embeddings'))
    helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
        embedding=embeddings,
        start_tokens=tf.fill([batch_size], 1),
        end_token=0)

So in my case, I was running the code from the same folder that contained the saved model (hence the '.' path), and my variable name was 'embed/embeddings'. Note that this only works with embeddings trained as part of a TensorFlow model. Otherwise, refer to the answer linked above.

To find the variable name when using the Estimator API, you can call the estimator's get_variable_names() method to get a list of all the variable names saved in the graph.
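As a minimal, self-contained sketch of the same idea (assuming TF 1.x-style graph APIs via tf.compat.v1; the toy model, temp directory, and the 'embed/embeddings' name mirror the answer above, not any specific codebase), you can list every variable stored in a checkpoint with tf.train.list_variables — this reports the same names as estimator.get_variable_names() — and then load the one you want back as a NumPy array:

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Toy graph with an embedding variable under the scope "embed",
# mirroring the 'embed/embeddings' name used above (hypothetical model).
with tf.variable_scope("embed"):
    embedding_var = tf.get_variable(
        "embeddings",
        initializer=np.arange(12, dtype=np.float32).reshape(4, 3))

# Train-and-save stand-in: initialize and checkpoint the variable.
model_dir = tempfile.mkdtemp()
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, os.path.join(model_dir, "model.ckpt"))

# List all variable names stored in the checkpoint.
names = [name for name, _ in tf.train.list_variables(model_dir)]
print(names)  # 'embed/embeddings' should appear in this list

# Load the trained embedding matrix back as a plain NumPy array,
# ready to feed into a decoder helper at prediction time.
loaded = tf.train.load_variable(model_dir, "embed/embeddings")
print(loaded.shape)
```

Because load_variable returns a NumPy array rather than a graph tensor, wrapping it in tf.constant (as in the snippet above) embeds the trained weights directly into the prediction graph.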

cad86