I'm trying to create this super simple example with TensorFlow, and I clearly don't fully understand the TensorFlow API.
I have the following code. It isn't originally mine: I found it in a demo, but I can't remember where, or I would give the author credit. Apologies.
Saving the Trained Line Model
```python
import tensorflow as tf
import numpy as np

# Create 100 phony x, y data points in NumPy, y = x * 0.1 + 0.3
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3

# Try to find values for W and b that compute y_data = W * x_data + b
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='W')
b = tf.Variable(tf.zeros([1]), name='b')
y = W * x_data + b

# Minimize the mean squared errors.
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

# Before starting, initialize the variables. We will 'run' this first.
init = tf.global_variables_initializer()

# Create a session saver
saver = tf.train.Saver()

# Launch the graph.
sess = tf.Session()
sess.run(init)

# Fit the line.
for step in range(201):
    sess.run(train)
    if step % 20 == 0:
        print(step, sess.run(W), sess.run(b))

saver.save(sess, 'linemodel')
```
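Here is a minimal NumPy sketch of what each `sess.run(train)` call does in the script above: full-batch gradient descent on the mean squared error, with learning rate 0.5 and 201 steps. The gradients are written out by hand (2·mean(residual·x) for `W`, 2·mean(residual) for `b`). TensorFlow derives them automatically, so this illustrates the math only, not TensorFlow's internals. Note that `x_data` is a fixed array here, just as it is baked into the graph above.

```python
import numpy as np

# Same synthetic data as the TensorFlow script: y = 0.1 * x + 0.3
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3

W = np.random.uniform(-1.0, 1.0)  # plays the role of tf.random_uniform([1], -1.0, 1.0)
b = 0.0                           # plays the role of tf.zeros([1])
learning_rate = 0.5

for step in range(201):
    residual = W * x_data + b - y_data           # y - y_data
    # Hand-derived gradients of mean(residual ** 2) w.r.t. W and b
    grad_W = 2.0 * np.mean(residual * x_data)
    grad_b = 2.0 * np.mean(residual)
    W -= learning_rate * grad_W
    b -= learning_rate * grad_b
    if step % 20 == 0:
        print(step, W, b)
```

Because the data is noise-free, `W` and `b` should end up very close to 0.1 and 0.3.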
OK, that part works. Now I just want to load the model and query it to get a predicted value. Here is my attempted code:
Loading and Querying the Trained Line Model
```python
# This is going to load the line model
import tensorflow as tf

sess = tf.Session()
new_saver = tf.train.import_meta_graph('linemodel.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))  # latest checkpoint
all_vars = tf.global_variables()
for v in all_vars:
    v_ = sess.run(v)
    print("This is {} with value: {}".format(v.name, v_))
# this works

# None of the below works
# Tried this as well
#fetches = {
#    "input": tf.constant(10, name='input')
#}
#feed_dict = {"input": tf.constant(10, name='input')}
#vals = sess.run(fetches, feed_dict = feed_dict)

# Tried this and it didn't work
# query_value = tf.constant(10, name='query')
# print(sess.run(query_value))
```
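The restore loop already gets `W` and `b` back as NumPy arrays through `sess.run(v)`. That allows a stopgap that bypasses the graph entirely: evaluate the line in plain NumPy. The helper `make_line` and the values 0.1 and 0.3 below are hypothetical. They are what training should converge to, not values read from a checkpoint.

```python
import numpy as np

def make_line(W, b):
    """Hypothetical helper: wrap trained parameters (e.g. the arrays
    returned by sess.run(v) in the restore loop) as a plain function."""
    W = np.asarray(W, dtype=np.float32)
    b = np.asarray(b, dtype=np.float32)

    def predict(x):
        # Same formula as the graph: y = W * x + b, with broadcasting
        return W * np.asarray(x, dtype=np.float32) + b

    return predict

# Hypothetical parameter values, shaped like the [1]-element variables
line = make_line([0.1], [0.3])
print(line(10.0))         # close to 0.1 * 10 + 0.3
print(line([0.0, 1.0]))   # one prediction per input value
```

This works only because the model is a simple formula you can re-implement. For anything larger, you would want to feed inputs into the restored graph itself.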
This is a really basic question, but how can I just pass in a value and use my line almost like a function? Do I need to change the way the line model is constructed? My guess is that the computation graph isn't set up with an input I can feed and an output tensor I can actually fetch. Is that correct? If so, how should I modify this program?
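In the training code, `y = W * x_data + b` multiplies by the NumPy array `x_data`, which TensorFlow converts into a constant inside the graph. The saved graph therefore has no input slot to feed. The `tf.constant(...)` attempts add new nodes that nothing else is connected to, so running them just returns the constant. One common pattern is to declare the input as a named `tf.placeholder` and give the output op a name, so both can be looked up with `get_tensor_by_name` after `import_meta_graph`.

The sketch below assumes a TensorFlow release that ships `tensorflow.compat.v1`, which exposes the TF1-style API used in the question. It calls `disable_eager_execution()` so placeholders and `Session` work under TF 2.x. The tensor names `'x'` and `'y'` and the temporary checkpoint directory are my own choices, not part of the original code.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf  # TF1-style API (assumes a TF version with compat.v1)

tf.disable_eager_execution()  # graph mode, needed for placeholders/Session in TF 2.x

ckpt_dir = tempfile.mkdtemp()                    # assumption: scratch directory
ckpt_prefix = os.path.join(ckpt_dir, 'linemodel')

# ---- Training: x is a named placeholder instead of a baked-in array ----
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3

x = tf.placeholder(tf.float32, shape=[None], name='x')   # input we can feed later
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='W')
b = tf.Variable(tf.zeros([1]), name='b')
y = tf.add(W * x, b, name='y')                           # named output we can fetch later

loss = tf.reduce_mean(tf.square(y - y_data))
train = tf.train.GradientDescentOptimizer(0.5).minimize(loss)
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(201):
        sess.run(train, feed_dict={x: x_data})  # training data goes in via the placeholder
    saver.save(sess, ckpt_prefix)

# ---- Loading and querying in a fresh graph ----
tf.reset_default_graph()
with tf.Session() as sess:
    loader = tf.train.import_meta_graph(ckpt_prefix + '.meta')
    loader.restore(sess, tf.train.latest_checkpoint(ckpt_dir))
    x_in = sess.graph.get_tensor_by_name('x:0')   # ':0' = first output of the op
    y_out = sess.graph.get_tensor_by_name('y:0')
    prediction = sess.run(y_out, feed_dict={x_in: [10.0]})

print(prediction)  # should be close to 0.1 * 10 + 0.3
```

The key change is the named placeholder. Because `x` has shape `[None]`, the same graph accepts the 100 training points during fitting and a single query value afterwards. Only `y` is fetched at query time, so the `y_data` constant in the loss is never evaluated.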