Related to this: How can I copy a variable in tensorflow
I am trying to copy the values of an LSTM decoding unit so that I can use them elsewhere for beam search. In pseudocode, I would like something like this:
lstm_decode = tf.nn.rnn_cell(...)
training_output = tf.nn.seq2seq.rnn_decoder(...)
... do training by back-propagating the error on training_output ...
# duplicate the lstm_decode unit (same weights)
lstm_decode_copy = copy(lstm_decode)
... do beam search with the duplicated lstm ...
The issue is that in TensorFlow, the LSTM variables are not created by the call to "tf.nn.rnn_cell(...)"; they are actually created when rnn_decoder unrolls the cell.
I could set a scope around the "tf.nn.seq2seq.rnn_decoder" function call, but the actual initialization of the LSTM weights is not transparent to me. How might I capture these values and reuse them to make an LSTM cell with the same weights as the ones learned?
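To make the behavior described above concrete, here is a toy sketch in plain Python. It is not TensorFlow code: VariableStore, ToyCell, and the "decoder" scope name are hypothetical names made up for illustration. It only models the idea that a cell object holds configuration, that the weights are created on the first call under a scope, and that a second cell called under the same scope with reuse turned on would see the same weights. Whether the real rnn_decoder can be driven this way is an assumption, not something confirmed here.

```python
# Toy model only: plain Python, no TensorFlow. All names are hypothetical.

class VariableStore:
    """Maps scoped names to values, loosely like a graph's variable collection."""

    def __init__(self):
        self.variables = {}

    def get_variable(self, name, initializer, reuse=False):
        if name in self.variables:
            if not reuse:
                raise ValueError("%s already exists; pass reuse=True" % name)
            return self.variables[name]
        if reuse:
            raise ValueError("%s does not exist; cannot reuse it" % name)
        self.variables[name] = initializer()
        return self.variables[name]


class ToyCell:
    """Holds only configuration; it owns no weights until it is called."""

    def __init__(self, num_units):
        self.num_units = num_units

    def __call__(self, store, scope, x, reuse=False):
        # The weights are looked up (or created) here, at call time.
        w = store.get_variable(scope + "/weights",
                               lambda: [0.5] * self.num_units,
                               reuse=reuse)
        return [wi * x for wi in w]


store = VariableStore()
cell = ToyCell(num_units=3)
vars_after_construction = len(store.variables)  # constructing the cell creates nothing

train_out = cell(store, "decoder", 2.0)         # first call creates decoder/weights
vars_after_first_call = len(store.variables)

# Stand-in for training: change the stored weights in place.
store.variables["decoder/weights"][0] = 1.0

# A second cell object, called under the same scope with reuse=True,
# reads the same (now updated) weights instead of creating new ones.
cell_copy = ToyCell(num_units=3)
beam_out = cell_copy(store, "decoder", 2.0, reuse=True)
```

In this toy, no explicit copy is made: the "duplicate" cell simply resolves to the same stored weights by name. That is the same idea as variable sharing through a variable scope with reuse enabled, though the sketch says nothing about how the seq2seq helpers name their variables internally.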
Thanks!