
How can I list all TensorFlow variables/constants/placeholders a node depends on?

Example 1 (addition of constants):

import tensorflow as tf

a = tf.constant(1, name = 'a')
b = tf.constant(3, name = 'b')
c = tf.constant(9, name = 'c')
d = tf.add(a, b, name='d')
e = tf.add(d, c, name='e')

sess = tf.Session()
print(sess.run([d, e]))

I would like to have a function list_dependencies() such that:

  • list_dependencies(d) returns ['a', 'b']
  • list_dependencies(e) returns ['a', 'b', 'c']

Example 2 (matrix multiplication between a placeholder and a weight matrix, followed by the addition of a bias vector):

tf.set_random_seed(1)
input_size  = 5
output_size = 3
input       = tf.placeholder(tf.float32, shape=[1, input_size], name='input')
W           = tf.get_variable(
                "W",
                shape=[input_size, output_size],
                initializer=tf.contrib.layers.xavier_initializer())
b           = tf.get_variable(
                "b",
                shape=[output_size],
                initializer=tf.constant_initializer(2))
output      = tf.matmul(input, W, name="output")
output_bias = tf.nn.xw_plus_b(input, W, b, name="output_bias")

sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run([output,output_bias], feed_dict={input: [[2]*input_size]}))

I would like to have a function list_dependencies() such that:

  • list_dependencies(output) returns ['W', 'input']
  • list_dependencies(output_bias) returns ['W', 'b', 'input']
Franck Dernoncourt

4 Answers


Here are utilities I use for this (from https://github.com/yaroslavvb/stuff/blob/master/linearize/linearize.py):

import tensorflow as tf

# computation flows from parents to children

def parents(op):
  # ops that produce the tensors this op consumes
  return set(inp.op for inp in op.inputs)

def children(op):
  # ops that consume any of the tensors this op produces
  return set(consumer for out in op.outputs for consumer in out.consumers())

def get_graph():
  """Creates dictionary {node: {child1, child2, ..},..} for current
  TensorFlow graph. Result is compatible with networkx/toposort"""

  ops = tf.get_default_graph().get_operations()
  return {op: children(op) for op in ops}


def print_tf_graph(graph):
  """Prints tensorflow graph in dictionary form."""
  for node in graph:
    for child in graph[node]:
      print("%s -> %s" % (node.name, child.name))

These functions work on ops. To get the op that produces tensor t, use t.op; to get the tensors produced by an op, use op.outputs.
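Building on parents(), the list_dependencies() function from the question could be sketched as a backward walk from t.op that keeps the ops with no inputs (constants, placeholders, variable ops). The sketch below takes a generic parents callable so it can be exercised without TensorFlow; the function body and the toy graph are illustrative assumptions, not part of the answer above, and control dependencies are not followed.

```python
def list_dependencies(start, parents):
    """Return the source nodes (nodes with no parents) that `start` depends on.

    `parents` is any callable mapping a node to an iterable of its parents.
    With TensorFlow 1.x this could presumably be the parents() helper above,
    called starting from tensor.op. Control dependencies are not followed.
    """
    sources = set()
    visited = {start}
    stack = [start]
    while stack:
        node = stack.pop()
        node_parents = list(parents(node))
        if not node_parents:
            # No inputs: e.g. a constant, placeholder or variable op.
            sources.add(node)
        for p in node_parents:
            if p not in visited:
                visited.add(p)
                stack.append(p)
    return sources


# Toy stand-in for example 1: each node maps to the list of its parents.
example1 = {"a": [], "b": [], "c": [], "d": ["a", "b"], "e": ["d", "c"]}
print(sorted(list_dependencies("d", example1.get)))  # ['a', 'b']
print(sorted(list_dependencies("e", example1.get)))  # ['a', 'b', 'c']
```

With TensorFlow the call would presumably look like sorted(op.name for op in list_dependencies(e.op, parents)). For variables, the walk should pass through the read op and stop at the variable op itself, but I have not checked this against every TensorFlow version.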

Yaroslav Bulatov
  • It might be a good idea to contribute that in [graph_util](https://cs.corp.google.com/piper///depot/google3/third_party/tensorflow/python/framework/graph_util_impl.py?q=file:third_party/tensorflow.*graph_util&sq=package:piper+file://depot/google3+-file:google3/experimental&dr&l=110), or via contrib. – drpng Feb 15 '17 at 20:55
  • It would seem that this solution will return all of the child ops in the graph, not just those of a particular node. – Johiasburg Frowell Mar 04 '18 at 19:12
  • why did tensorflow have to be sooo bad man.. just like windows it wastes tons of human hours – figs_and_nuts Apr 24 '19 at 12:50

Yaroslav Bulatov's answer is great. I'll just add one plotting function that uses Yaroslav's get_graph() and children() functions:

import matplotlib.pyplot as plt
import networkx as nx
def plot_graph(G):
    '''Plot a DAG using NetworkX'''        
    def mapping(node):
        return node.name
    G = nx.DiGraph(G)
    nx.relabel_nodes(G, mapping, copy=False)
    nx.draw(G, cmap = plt.get_cmap('jet'), with_labels = True)
    plt.show()

plot_graph(get_graph())

Plotting example 1 from the question:

import matplotlib.pyplot as plt
import networkx as nx
import tensorflow as tf

def children(op):
  # ops that consume any of the tensors this op produces
  return set(consumer for out in op.outputs for consumer in out.consumers())

def get_graph():
  """Creates dictionary {node: {child1, child2, ..},..} for current
  TensorFlow graph. Result is compatible with networkx/toposort"""
  ops = tf.get_default_graph().get_operations()
  return {op: children(op) for op in ops}

def plot_graph(G):
    '''Plot a DAG using NetworkX'''        
    def mapping(node):
        return node.name
    G = nx.DiGraph(G)
    nx.relabel_nodes(G, mapping, copy=False)
    nx.draw(G, cmap = plt.get_cmap('jet'), with_labels = True)
    plt.show()

a = tf.constant(1, name = 'a')
b = tf.constant(3, name = 'b')
c = tf.constant(9, name = 'c')
d = tf.add(a, b, name='d')
e = tf.add(d, c, name='e')

sess = tf.Session()
print(sess.run([d, e]))
plot_graph(get_graph())

Output:

[image: NetworkX plot of the example 1 graph]

Plotting example 2 from the question:

[image: NetworkX plot of the example 2 graph]

If you use Microsoft Windows, you may run into this issue: Python Error (ValueError: _getfullpathname: embedded null character), in which case you need to patch matplotlib as the link explains.

Franck Dernoncourt
  • btw, if you use jupyter, you can also use http://stackoverflow.com/questions/38189119/simple-way-to-visualize-a-tensorflow-graph-in-jupyter , that lets you collapse some nodes – Yaroslav Bulatov Feb 16 '17 at 00:00

These are all excellent answers. I'll add a simple approach that shows the dependencies in a less readable format but can be useful for quick debugging:

tf.get_default_graph().as_graph_def()

Printing the result shows every operation in the graph in protobuf text format, as below. Each op is easy to spot by name, together with its attributes and inputs, so you can follow the dependencies.

import tensorflow as tf

a = tf.placeholder(tf.float32, name='placeholder_1')
b = tf.placeholder(tf.float32, name='placeholder_2')
c = a + b

tf.get_default_graph().as_graph_def()

Out[14]: 
node {
  name: "placeholder_1"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        unknown_rank: true
      }
    }
  }
}
node {
  name: "placeholder_2"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        unknown_rank: true
      }
    }
  }
}
node {
  name: "add"
  op: "Add"
  input: "placeholder_1"
  input: "placeholder_2"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
versions {
  producer: 27
}
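The same dependencies can also be followed programmatically from the GraphDef. Each node lists its inputs as strings: "name" for an op's first output, "name:k" for output k, and "^name" for a control dependency. A minimal sketch, assuming the GraphDef has already been reduced to a plain {node name: list of input strings} dictionary (with TensorFlow, something like {n.name: list(n.input) for n in graph_def.node}); both helper names are hypothetical:

```python
def input_node_name(input_str):
    """Strip the control-dependency prefix and output index from an input."""
    return input_str.lstrip("^").split(":")[0]


def graph_def_dependencies(node_inputs, target, follow_control=False):
    """Return the names of source nodes (no inputs) that `target` depends on.

    `node_inputs` maps node name -> list of GraphDef-style input strings.
    Control inputs ("^name") are ignored unless follow_control is True.
    """
    sources = set()
    seen = {target}
    stack = [target]
    while stack:
        name = stack.pop()
        inputs = [i for i in node_inputs[name]
                  if follow_control or not i.startswith("^")]
        if not inputs:
            sources.add(name)
        for i in inputs:
            parent = input_node_name(i)
            if parent not in seen:
                seen.add(parent)
                stack.append(parent)
    return sources


# Plain-dict version of the GraphDef printed above.
node_inputs = {
    "placeholder_1": [],
    "placeholder_2": [],
    "add": ["placeholder_1", "placeholder_2"],
}
print(sorted(graph_def_dependencies(node_inputs, "add")))
# ['placeholder_1', 'placeholder_2']
```

Working on names rather than Operation objects has the advantage that it also applies to a GraphDef loaded from a file, without importing it into a graph first.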
David Parks

In some cases one may want to find all the "input" variables that are connected to an "output" tensor, such as the loss of a graph. To this end, the following code snippet may be useful (inspired by the code above):

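def findVars(atensor):
    """Recursively collect the variable ops that atensor depends on."""
    allinputs = atensor.op.inputs
    if len(allinputs) == 0:
        # Leaf op: keep it only if it is a variable
        # (VarHandleOp is the op type used by resource variables).
        if atensor.op.type in ('VariableV2', 'Variable', 'VarHandleOp'):
            return set([atensor.op])
        return set()
    a = set()
    for t in allinputs:
        a = a | findVars(t)
    return a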

This can be used in debugging to find out where a connection in the graph is missing.
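For that debugging use case, one possible variant (a sketch, not part of the answer above) walks the graph iteratively with a visited set, so shared subgraphs are not walked again as in the recursive version, and reports which expected nodes the output never reaches. It uses a generic parents callable and a toy graph; the names are hypothetical. With TensorFlow, the call might look like unreachable(loss.op, [v.op for v in tf.trainable_variables()], parents), using the parents() helper from the first answer.

```python
def unreachable(target, expected, parents):
    """Return the nodes in `expected` that `target` does not depend on.

    `parents` is any callable mapping a node to an iterable of its parents.
    """
    seen = {target}
    stack = [target]
    while stack:
        node = stack.pop()
        for p in parents(node):
            if p not in seen:
                seen.add(p)
                stack.append(p)
    return [n for n in expected if n not in seen]


# Toy graph where the bias "b" was accidentally left out of the loss.
graph = {"W": [], "b": [], "x": [], "y": ["x", "W"], "loss": ["y"]}
print(unreachable("loss", ["W", "b"], graph.get))  # ['b']
```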