In TensorFlow, is there a way to find all placeholder tensors that are required to evaluate a certain output tensor? That is, is there a function that will return all (placeholder) tensors that must be fed into feed_dict when sess.run(output_tensor) is called?

Here's an example of what I'd like to do, in pseudocode:

import tensorflow as tf  # TensorFlow 1.x graph API; in 2.x use tf.compat.v1 with eager execution disabled

a = tf.placeholder(dtype=tf.float32, shape=())
b = tf.placeholder(dtype=tf.float32, shape=())
c = tf.placeholder(dtype=tf.float32, shape=())
d = a + b
f = b + c

# This should return [a,b] or [a.name,b.name]
d_input_tensors = get_dependencies(d)

# This should return [b,c] or [b.name,c.name]
f_input_tensors = get_dependencies(f)

EDIT: To clarify, I am not (necessarily) looking for all of the placeholders in the graph, just the ones required to evaluate a particular output tensor. These are likely to be only a subset of all placeholders in the graph.

  • For getting all placeholders in the graph, there is an answer: https://stackoverflow.com/a/44371483/4834515. As for getting only the dependencies... no idea. – LI Xuhong Oct 11 '17 at 16:01
  • @Seven I'd like to get just the dependencies, not all of the placeholders. I'll edit my question to clarify. – Johiasburg Frowell Oct 11 '17 at 20:36

1 Answer

After some tinkering, and after finding a nearly identical SO question, I came up with the following solution:

def get_tensor_dependencies(tensor):

    # If a tensor is passed in, get its op; an Operation has no .op attribute
    try:
        tensor_op = tensor.op
    except AttributeError:
        tensor_op = tensor

    # Recursively analyze inputs
    dependencies = []
    for inp in tensor_op.inputs:
        new_d = get_tensor_dependencies(inp)
        non_repeated = [d for d in new_d if d not in dependencies]
        dependencies = [*dependencies, *non_repeated]

    # If we've reached the "end", return the op's name
    if len(tensor_op.inputs) == 0:
        dependencies = [tensor_op.name]

    # Return a list of tensor op names
    return dependencies
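The recursive version above re-traverses a shared subgraph once for every path that reaches it, and on very deep graphs it can hit Python's recursion limit. Below is a minimal sketch of the same leaf search written as an iterative depth-first traversal with a visited set. It relies only on the attributes the answer already uses (tensor.op, op.inputs, op.name) and assumes op names are unique within a graph, as TensorFlow enforces. The ToyOp and ToyTensor classes are hypothetical stand-ins for tf.Operation and tf.Tensor, included only so the sketch runs without TensorFlow.

```python
class ToyOp:
    """Hypothetical stand-in for tf.Operation (only .name, .type, .inputs)."""

    def __init__(self, name, op_type, inputs=()):
        self.name = name
        self.type = op_type
        self.inputs = list(inputs)  # list of ToyTensor, like tf.Operation.inputs
        self.output = ToyTensor(self)


class ToyTensor:
    """Hypothetical stand-in for tf.Tensor (only .op)."""

    def __init__(self, op):
        self.op = op


def get_tensor_dependencies(tensor):
    """Return names of input-free ops that `tensor` depends on, in the order reached."""
    # Accept either a tensor (which has .op) or an operation directly
    root = getattr(tensor, "op", tensor)
    seen = set()
    leaves = []
    stack = [root]
    while stack:
        op = stack.pop()
        if op.name in seen:
            continue  # already expanded via another path
        seen.add(op.name)
        inputs = list(op.inputs)
        if not inputs:
            leaves.append(op.name)
        else:
            # Push in reverse so the first input is expanded first
            stack.extend(inp.op for inp in reversed(inputs))
    return leaves


# Rebuild the question's example with the toy classes
a = ToyOp("a", "Placeholder").output
b = ToyOp("b", "Placeholder").output
c = ToyOp("c", "Placeholder").output
d = ToyOp("add", "Add", [a, b]).output
f = ToyOp("add_1", "Add", [b, c]).output

print(get_tensor_dependencies(d))  # ['a', 'b']
print(get_tensor_dependencies(f))  # ['b', 'c']
```

Because of the visited set, each op is expanded at most once, so the work grows with the size of the subgraph rather than with the number of paths through it.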

Note: This returns not only placeholders but every op with no inputs, which includes variables and constants. If dependencies = [tensor_op.name] is replaced by dependencies = [tensor_op.name] if tensor_op.type == 'Placeholder' else [], then only placeholders will be returned.
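A sketch of that placeholder-only filter applied to an iterative traversal: an input-free op is kept only if its type is 'Placeholder', so constant and variable leaves are dropped. As before, ToyOp and ToyTensor are hypothetical stand-ins for tf.Operation and tf.Tensor, and the 'Const' and 'VariableV2' type strings are used here only to label the toy leaves.

```python
class ToyOp:
    """Hypothetical stand-in for tf.Operation (only .name, .type, .inputs)."""

    def __init__(self, name, op_type, inputs=()):
        self.name = name
        self.type = op_type
        self.inputs = list(inputs)
        self.output = ToyTensor(self)


class ToyTensor:
    """Hypothetical stand-in for tf.Tensor (only .op)."""

    def __init__(self, op):
        self.op = op


def get_placeholder_dependencies(tensor):
    """Return names of Placeholder ops that `tensor` depends on."""
    root = getattr(tensor, "op", tensor)
    seen = set()
    placeholders = []
    stack = [root]
    while stack:
        op = stack.pop()
        if op.name in seen:
            continue
        seen.add(op.name)
        inputs = list(op.inputs)
        if not inputs and op.type == "Placeholder":
            placeholders.append(op.name)  # keep Placeholder leaves; other leaves are skipped
        # Push inputs in reverse so the first input is expanded first
        stack.extend(inp.op for inp in reversed(inputs))
    return placeholders


a = ToyOp("a", "Placeholder").output
b = ToyOp("b", "Placeholder").output
w = ToyOp("w", "VariableV2").output
k = ToyOp("k", "Const").output
y = ToyOp("mul", "Mul", [ToyOp("add", "Add", [a, b]).output, w]).output
z = ToyOp("sub", "Sub", [y, k]).output

print(get_placeholder_dependencies(z))  # ['a', 'b']
```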