If you're using the TensorFlow backend, then apart from `plot_model` you can also use the `keras.callbacks.TensorBoard` callback to visualize the whole graph in TensorBoard. Example:
```python
import keras

callback = keras.callbacks.TensorBoard(log_dir='./graph',
                                       histogram_freq=0,
                                       write_graph=True,
                                       write_images=True)
model.fit(..., callbacks=[callback])
```
Then run `tensorboard --logdir ./graph` from the same directory.
This is a quick shortcut, but you can go even further.
For example, add TensorFlow code that defines (loads) the model within a custom `tf.Graph` instance, like this:
```python
from keras.layers import LSTM
import tensorflow as tf

my_graph = tf.Graph()
with my_graph.as_default():
    # All ops / variables in the LSTM layer are created as part of our graph
    x = tf.placeholder(tf.float32, shape=(None, 20, 64))
    y = LSTM(32)(x)
```
...after which you can list all graph nodes with their dependencies, evaluate any variable, display the graph topology, and so on, to compare the models.
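A minimal sketch of those steps follows, assuming TensorFlow 2.x, so the TF1-style calls go through `tf.compat.v1`. The toy graph `y = x @ w` is a hypothetical stand-in for the model, and the temporary log directory is an assumption too. It lists each op with the ops feeding its inputs, evaluates a node by feeding the placeholder, and writes the graph to an event file that TensorBoard can display.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical toy graph standing in for the model: y = x @ w.
graph = tf.Graph()
with graph.as_default():
    # Placeholders only work in graph mode, i.e. inside graph.as_default().
    x = tf.compat.v1.placeholder(tf.float32, shape=(None, 3), name="x")
    w = tf.constant([[1.0], [2.0], [3.0]], name="w")
    y = tf.matmul(x, w, name="y")

    # Write the graph topology for TensorBoard. FileWriter is graph-mode only,
    # so it is created inside the same context.
    log_dir = tempfile.mkdtemp(prefix="graph_")
    writer = tf.compat.v1.summary.FileWriter(log_dir, graph=graph)
    writer.close()

# List every op together with the names of the ops that feed its inputs.
deps = {op.name: [t.op.name for t in op.inputs] for op in graph.get_operations()}
for name, inputs in deps.items():
    print(f"{name} <- {inputs}")

# Evaluate a node by feeding the placeholder.
with tf.compat.v1.Session(graph=graph) as sess:
    result = sess.run(y, feed_dict={x: [[1.0, 1.0, 1.0]]})
print(result)  # [[6.]]

event_files = [f for f in os.listdir(log_dir) if f.startswith("events.out.tfevents")]
print("Run: tensorboard --logdir", log_dir)
```

Pointing `tensorboard --logdir` at the printed directory should show the graph in the Graphs tab, the same way the callback's `./graph` directory does.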
Personally, I think the simplest way is to set up your own session. It works in all cases with minimal patching:
```python
import tensorflow as tf
from keras import backend as K

sess = tf.Session()
K.set_session(sess)
...
# Now you can evaluate / access any node in this session, e.g. `sess.graph`
```
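To illustrate what "access any node" through `sess.graph` can look like, here is a sketch without Keras, assuming TensorFlow 2.x with the `tf.compat.v1.Session` API. The graph and the node names `a`, `b`, `c` are hypothetical. Once you hold the session, nodes can be enumerated and looked up by name, then evaluated.

```python
import tensorflow as tf

# Hypothetical graph; with Keras this would be whatever the model built
# in the session passed to K.set_session.
graph = tf.Graph()
with graph.as_default():
    a = tf.constant(2.0, name="a")
    b = tf.constant(3.0, name="b")
    c = tf.add(a, b, name="c")

sess = tf.compat.v1.Session(graph=graph)

# sess.graph is the graph this session runs, so any node can be found by name.
op_names = [op.name for op in sess.graph.get_operations()]
c_tensor = sess.graph.get_tensor_by_name("c:0")  # ":0" = first output of op "c"
value = sess.run(c_tensor)
print(op_names)  # ['a', 'b', 'c']
print(value)  # 5.0
sess.close()
```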