
There is a function tf.get_variable('name') which allows one to "implicitly" pass parameters into a function, like:

def function(sess, feed):
    with tf.variable_scope('training', reuse=True):
        cost = tf.get_variable('cost')
    value = sess.run(cost, feed_dict=feed) 
    # other statements

But what if one wants to pass a tf.placeholder into a function? Is there a similar mechanism for placeholders, i.e. something like tf.get_placeholder():

def function(sess, cost, X_train, y_train):
    # Note: this is NOT valid TF code
    with tf.variable_scope('training', reuse=True):
        features = tf.get_placeholder('features')
        labels = tf.get_placeholder('labels')
    feed = {features: X_train, labels: y_train}
    value = sess.run(cost, feed_dict=feed)
    print('Cost: %s' % value)    

Or does it not make much sense to do this, and is it better to just construct the placeholders inside the function?

devforfu

2 Answers


Placeholders are just... placeholders. It's pointless to "get" a placeholder as if it had some sort of state (that's what tf.get_variable does: it returns a variable in its current state).

Just use the same Python variable everywhere.
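
For instance, a minimal sketch of keeping the placeholders in ordinary Python variables and passing them around (the shapes, the cost expression and the run_cost helper below are made up for illustration):

import tensorflow as tf

# build the placeholders once and keep the Python references
features = tf.placeholder(tf.float32, shape=[None, 3], name='features')
labels = tf.placeholder(tf.float32, shape=[None, 3], name='labels')
cost = tf.reduce_mean(tf.square(features - labels))

def run_cost(sess, cost, features, labels, X_train, y_train):
    # the placeholders are just ordinary Python objects passed in
    feed = {features: X_train, labels: y_train}
    return sess.run(cost, feed_dict=feed)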

Also, if you don't want to pass a Python variable because your method signature becomes ugly, you can exploit the fact that you're building a graph, and the graph itself contains the information about the declared placeholders.

You can do something like:

import tensorflow as tf

# define your placeholder
a = tf.placeholder(tf.float32, name="asd")

# then, when you need it, fetch it from the graph by name
graph = tf.get_default_graph()
placeholder = graph.get_tensor_by_name("asd:0")
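
Applied to the question's example, a sketch of how the function could then look (this assumes the placeholders were originally created with name='features' and name='labels'):

def function(sess, cost, X_train, y_train):
    graph = tf.get_default_graph()
    # ":0" selects the first output of the placeholder op
    features = graph.get_tensor_by_name("features:0")
    labels = graph.get_tensor_by_name("labels:0")
    feed = {features: X_train, labels: y_train}
    value = sess.run(cost, feed_dict=feed)
    print('Cost: %s' % value)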
nessuno
  • Ok, got it. So basically, a placeholder is kind of "stateless" and creating new ones doesn't introduce any overhead, right? – devforfu Oct 06 '17 at 07:20
  • Yes. The overhead is introduced only when you use it, because of the data transfer from Python to TensorFlow. – nessuno Oct 06 '17 at 09:13

Aside from the fact that, if you are working in the same script, you should not need this, you can do it by getting the tensor by name, as in Tensorflow: How to get a tensor by name?

For instance:

import tensorflow as tf

p = tf.placeholder(tf.float32)
p2 = tf.get_default_graph().get_tensor_by_name(p.name)

assert p == p2
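
Note that p.name already includes the ":0" output-index suffix (for an unnamed placeholder it is something like "Placeholder:0"), which is why it can be passed to get_tensor_by_name directly.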
Pietro Tortella