Using LSTMCell, I trained a model to do text generation. I started a TensorFlow session, initialized all the TensorFlow variables with tf.global_variables_initializer() and saved them:
import tensorflow as tf
sess = tf.Session()
# ... code blocks ...
run_init_op = tf.global_variables_initializer()
sess.run(run_init_op)
saver = tf.train.Saver()
# variable that makes the prediction
prediction = tf.nn.softmax(tf.matmul(last, weight) + bias)
# feed the input data into the model and train it
# save the TensorFlow model
save_path = saver.save(sess, '/tmp/text_generate_trained_model.ckpt')
print("Model saved in the path : {}".format(save_path))
The model trains and the whole session is saved. The complete training code can be reviewed here: lstm_rnn.py
Next, I loaded the stored model to generate text for the document, restoring it with the following code:
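A point worth checking here is that get_tensor_by_name looks tensors up by their graph name, and the Python variable name (`prediction`) is never recorded in the graph; only the `name=` argument of the op is. The sketch below builds the softmax layer twice, once without and once with `name="prediction"`, and checks which graph can resolve `"prediction:0"`. It assumes a TensorFlow version that provides `tensorflow.compat.v1` (late 1.x releases and 2.x); `last`, `weight` and `bias` are hypothetical stand-ins with made-up shapes, not the tensors from lstm_rnn.py.

```python
import tensorflow.compat.v1 as tf


def build_graph(op_name=None):
    """Build a tiny graph shaped like the softmax layer in the question.

    op_name is passed to tf.nn.softmax as its `name=` argument.
    """
    graph = tf.Graph()
    with graph.as_default():
        # hypothetical stand-ins for `last`, `weight` and `bias`
        last = tf.placeholder(tf.float32, shape=[None, 4], name="last")
        weight = tf.ones([4, 3], name="weight")
        bias = tf.zeros([3], name="bias")
        # the Python variable name is not stored in the graph; only name= is
        prediction = tf.nn.softmax(tf.matmul(last, weight) + bias, name=op_name)
    return graph, prediction


def has_tensor(graph, tensor_name):
    """Return True if get_tensor_by_name can resolve tensor_name in graph."""
    try:
        graph.get_tensor_by_name(tensor_name)
        return True
    except KeyError:
        return False


unnamed_graph, unnamed_pred = build_graph()
named_graph, named_pred = build_graph("prediction")

print(unnamed_pred.name)                          # default name chosen by TensorFlow
print(has_tensor(unnamed_graph, "prediction:0"))  # False
print(has_tensor(named_graph, "prediction:0"))    # True
```

Wrapping an existing tensor with tf.identity(tensor, name="prediction") is another way to attach a fixed name before saving.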
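import sys

import numpy as np
import tensorflow as tf

tf.reset_default_graph()
# rebuild the graph structure from the .meta file written by saver.save
imported_meta = tf.train.import_meta_graph('/tmp/text_generate_trained_model.ckpt.meta')
with tf.Session() as sess:  # the session is closed automatically when the block ends
    # load the trained variable values from the latest checkpoint
    imported_meta.restore(sess, tf.train.latest_checkpoint('/tmp/'))
    # access the default graph which we restored
    graph = tf.get_default_graph()
    # the tensor holding the network's prediction
    y_pred = graph.get_tensor_by_name("prediction:0")
    # input placeholder; "X:0" is an assumed name, use the one defined in lstm_rnn.py
    X = graph.get_tensor_by_name("X:0")
    # generate characters (pattern, n_vocab and int_to_char are defined elsewhere in text_generation.py)
    for i in range(500):
        x = np.reshape(pattern, (1, len(pattern), 1))
        x = x / float(n_vocab)
        # feed_dict must map placeholder tensors to values, not be the bare array
        prediction = sess.run(y_pred, feed_dict={X: x})
        index = np.argmax(prediction)
        result = int_to_char[index]
        seq_in = [int_to_char[value] for value in pattern]
        sys.stdout.write(result)
        pattern.append(index)
        pattern = pattern[1:len(pattern)]
print("\n Done...!")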
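The generation loop above is greedy sliding-window sampling: each step predicts one index, appends it to `pattern` and drops the oldest entry, so the window length (and therefore the input shape (1, len(pattern), 1)) stays constant. The pure-Python sketch below isolates that window logic; `predict_next` is a hypothetical stand-in for sess.run(y_pred, ...) followed by np.argmax, and the three-character vocabulary is a toy example, not the document's data.

```python
def generate(seed, predict_next, int_to_char, steps):
    """Greedy sliding-window generation over integer-encoded characters.

    predict_next(window) is a hypothetical stand-in for running the model
    and taking argmax; it must return the next integer index.
    """
    pattern = list(seed)  # copy so the caller's seed is not modified
    out = []
    for _ in range(steps):
        index = predict_next(pattern)
        out.append(int_to_char[index])
        # slide the window: append the new index, drop the oldest one
        pattern.append(index)
        pattern = pattern[1:]
    return "".join(out), pattern


# toy vocabulary and a stub "model" that predicts (last index + 1) mod 3
int_to_char = {0: "a", 1: "b", 2: "c"}
stub = lambda window: (window[-1] + 1) % 3

text, final_window = generate([0, 1], stub, int_to_char, steps=4)
print(text)          # cabc
print(final_window)  # [1, 2]
```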
However, it turns out that the prediction tensor does not exist in the restored graph:
KeyError: "The name 'prediction:0' refers to a Tensor which does not exist. The operation, 'prediction', does not exist in the graph."
The full code is available here: text_generation.py
Although I saved all the TensorFlow variables, the prediction tensor cannot be found in the restored computation graph. What is wrong in my lstm_rnn.py file?
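To see which names a restored graph really contains, one option is to list its operations after import_meta_graph, and to locate the softmax by its op type rather than by an assumed name. The sketch below writes a small stand-in graph (softmax created without `name=`) to a temporary .meta file with tf.train.export_meta_graph, imports it into a fresh graph, and looks the output up by type. It assumes a TensorFlow version with `tensorflow.compat.v1`; the graph, file name and `logits` placeholder are hypothetical, not the model from lstm_rnn.py.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

# hypothetical stand-in for the trained model: softmax created WITHOUT name=
train_graph = tf.Graph()
with train_graph.as_default():
    logits = tf.placeholder(tf.float32, shape=[None, 3], name="logits")
    prediction = tf.nn.softmax(logits)  # graph name is chosen by TensorFlow
    meta_path = os.path.join(tempfile.mkdtemp(), "sketch_model.meta")
    # write the graph structure, the same kind of .meta file saver.save produces
    tf.train.export_meta_graph(filename=meta_path, graph=train_graph)

# import into a fresh graph, as the restore code does
restored = tf.Graph()
with restored.as_default():
    tf.train.import_meta_graph(meta_path)
    # every op name that actually exists in the imported graph
    op_names = [op.name for op in restored.get_operations()]
    # find the softmax op by its type instead of by an assumed name
    softmax_ops = [op for op in restored.get_operations() if op.type == "Softmax"]
    y_pred = softmax_ops[0].outputs[0]

print(op_names)
print(y_pred.name)  # the name that get_tensor_by_name would accept
```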
Thanks!