I'm trying to build a very simple example with TensorFlow, and I clearly don't fully understand its API.

I have the following code. It isn't originally mine: I found it in some demo, but I can't remember where, or I would give the author credit. Apologies.

Saving the Trained Line Model

import tensorflow as tf
import numpy as np

# Create 100 phony x, y data points in NumPy, y = x * 0.1 + 0.3
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3

# Try to find values for W and b that compute y_data = W * x_data + b
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='W')
b = tf.Variable(tf.zeros([1]), name='b')
y = W * x_data + b

# Minimize the mean squared errors.
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

# Before starting, initialize the variables.  We will 'run' this first.
init = tf.global_variables_initializer()

# Create a Saver to checkpoint the variables (W and b)
saver = tf.train.Saver()

# Launch the graph.
sess = tf.Session() 

sess.run(init)

# Fit the line.
for step in range(201):
    sess.run(train)
    if step % 20 == 0:
        print(step, sess.run(W), sess.run(b))
        saver.save(sess, 'linemodel')

OK, that all works. Now I just want to load the model and query it to get a predicted value. Here is my attempt:

Loading and Querying the Trained Line Model

# This is going to load the line model
import tensorflow as tf

sess = tf.Session()
new_saver = tf.train.import_meta_graph('linemodel.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./')) # latest checkpoint
all_vars = tf.global_variables()
for v in all_vars:
    v_ = sess.run(v)
    print("This is {} with value: {}".format(v.name, v_))
    # this works


# None of the below works
# Tried this as well
#fetches = {
#   "input": tf.constant(10, name='input')
#}

#feed_dict = {"input": tf.constant(10, name='input')}
#vals = sess.run(fetches, feed_dict = feed_dict)
# Tried this and it didn't work
# query_value = tf.constant(10, name='query')

# print(sess.run(query_value))

This is a really basic question, but how can I just pass in a value and use my line almost like a function? Do I need to change the way the line model is constructed? My guess is that the computation graph is not set up so that the output is an actual variable we can fetch. Is that correct? If so, how should I modify this program?

jlarks32

1 Answer

You have to create the TensorFlow graph again and load the saved weights into it. I added a couple of lines to your code, and it gives the desired outputs. Please check it.

import tensorflow as tf
import numpy as np

sess = tf.Session() 
new_saver = tf.train.import_meta_graph('linemodel.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./')) # latest checkpoint
all_vars = tf.global_variables()

# grab the restored variables (creation order: W, then b)
W = all_vars[0]
b = all_vars[1]

# build TF graph
x = tf.placeholder(tf.float32)
y = tf.add(tf.multiply(W,x),b)

# Session
# Do not run tf.global_variables_initializer() here: restore() has already
# set W and b, and re-initializing would overwrite them with fresh values.
print(sess.run(all_vars))
for i in range(2):
    x_ip = np.random.rand(10).astype(np.float32) # batch_size : 10
    vals = sess.run(y,feed_dict={x:x_ip})
    print(vals)

Output:

[array([ 0.1000001], dtype=float32), array([ 0.29999995], dtype=float32)]

The first line is the restored W and b. Each of the two loop iterations then prints ten predictions, each approximately 0.1 * x + 0.3, so for inputs drawn from [0, 1) every value lies between 0.3 and 0.4.
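The same rebuild-and-restore idea can also be written without `import_meta_graph`: build the model a second time with the same variable names and let `tf.train.Saver` match them to the checkpoint entries. This is a minimal, self-contained sketch, assuming the `tf.compat.v1` API (TensorFlow 1.14+ or 2.x; on older 1.x releases plain `import tensorflow as tf` exposes the same names). The temporary directory and the query values are illustrative, not from the original post.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf  # TF1-style graph/session API

tf.disable_v2_behavior()

# Illustrative location; the original code saved to './linemodel'.
ckpt_dir = tempfile.mkdtemp()
ckpt_path = os.path.join(ckpt_dir, "linemodel")

# --- Training graph: fit y = W * x + b to data from y = 0.1 * x + 0.3 ---
train_graph = tf.Graph()
with train_graph.as_default():
    x_data = np.random.rand(100).astype(np.float32)
    y_data = x_data * 0.1 + 0.3
    W = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name="W")
    b = tf.Variable(tf.zeros([1]), name="b")
    loss = tf.reduce_mean(tf.square(W * x_data + b - y_data))
    train = tf.train.GradientDescentOptimizer(0.5).minimize(loss)
    saver = tf.train.Saver()  # saves W and b
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(201):
            sess.run(train)
        saver.save(sess, ckpt_path)

# --- Inference graph: same variable names, but x is now a placeholder ---
infer_graph = tf.Graph()
with infer_graph.as_default():
    # Names "W" and "b" must match the checkpoint; initial values are ignored.
    W = tf.Variable(tf.zeros([1]), name="W")
    b = tf.Variable(tf.zeros([1]), name="b")
    x = tf.placeholder(tf.float32, shape=[None], name="x")
    y = W * x + b
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # restore() assigns the saved values, so no initializer is run.
        saver.restore(sess, tf.train.latest_checkpoint(ckpt_dir))
        x_query = np.array([0.0, 0.5, 1.0, 10.0], dtype=np.float32)
        preds = sess.run(y, feed_dict={x: x_query})
        print(preds)
```

Because `restore()` assigns every saved variable, the inference graph never needs `tf.global_variables_initializer()`; the model can then be queried with any x, including the value 10 from the question.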

I hope this helps.

Harsha Pokkalla
  • So you essentially have to rebuild the graph, i.e. what the model was doing? The structure isn't saved? There's no way to get at `y` without writing `x = tf.placeholder(tf.float32); y = tf.add(tf.multiply(W,x),b)`? Thanks – jlarks32 Mar 01 '17 at 01:36
  • Yes, it does save only variables. We have to create the graph again and reload the saved variables. I found that there is a tf.import_graph_def() function to load the model definition, which you can try. I will probably try it later. – Harsha Pokkalla Mar 01 '17 at 03:02
  • Check this: http://stackoverflow.com/questions/33759623/tensorflow-how-to-save-restore-a-model-python/33762168#33762168 – Harsha Pokkalla Mar 01 '17 at 03:20
  • Cool! Thanks for the help as always. – jlarks32 Mar 01 '17 at 14:22
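On the question raised in the comments: `tf.train.import_meta_graph` does restore the graph structure stored in `linemodel.meta`; only the variable values come from the checkpoint files. The catch in the question's script is that `x_data` is baked into the graph as a constant and `y` has an auto-generated name, so there is nothing to feed and nothing easy to fetch. A minimal sketch of the alternative, assuming the `tf.compat.v1` API and the hypothetical tensor names `"x"` and `"y"` (chosen here for illustration): change the training script to feed a named placeholder and name the output, then look both up with `get_tensor_by_name` after importing.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf  # TF1-style graph/session API

tf.disable_v2_behavior()

# Illustrative location; the original code saved to './linemodel'.
ckpt_dir = tempfile.mkdtemp()
ckpt_path = os.path.join(ckpt_dir, "linemodel")

# --- Training script, changed to take x through a named placeholder ---
with tf.Graph().as_default():
    x_data = np.random.rand(100).astype(np.float32)
    y_data = x_data * 0.1 + 0.3
    x = tf.placeholder(tf.float32, shape=[None], name="x")   # hypothetical name
    W = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name="W")
    b = tf.Variable(tf.zeros([1]), name="b")
    y = tf.identity(W * x + b, name="y")                      # hypothetical name
    loss = tf.reduce_mean(tf.square(y - y_data))
    train = tf.train.GradientDescentOptimizer(0.5).minimize(loss)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(201):
            sess.run(train, feed_dict={x: x_data})
        saver.save(sess, ckpt_path)  # writes linemodel.meta + checkpoint files

# --- Query script: no model code, only the saved graph and weights ---
with tf.Graph().as_default() as graph:
    with tf.Session() as sess:
        loader = tf.train.import_meta_graph(ckpt_path + ".meta")
        loader.restore(sess, tf.train.latest_checkpoint(ckpt_dir))
        x_in = graph.get_tensor_by_name("x:0")    # ":0" = first output of op
        y_out = graph.get_tensor_by_name("y:0")
        x_query = np.array([0.0, 0.5, 1.0, 10.0], dtype=np.float32)
        preds = sess.run(y_out, feed_dict={x_in: x_query})
        print(preds)
```

With this layout the query script uses the line like a function: feed NumPy values into `x:0` and fetch `y:0`, without redefining `W`, `b`, or any operations.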