I am trying to display a histogram of all the network weights (CNN) at each epoch in TensorBoard, using the LambdaCallback of TensorFlow 2, as follows:
    import tensorflow as tf
    from tensorflow import keras

    def log_hist_weights(model, writer):
        def log_hist_weights(epoch, logs):
            # get_weights() returns a flat list of unnamed arrays for every layer
            Ws = model.get_weights()
            with writer.as_default():
                tf.summary.histogram("epoch: " + str(epoch), Ws)
        return log_hist_weights

    hist_callback = keras.callbacks.LambdaCallback(on_epoch_end=log_hist_weights(baseline_model, file_writer))
But the problem is that get_weights() returns all the network weights without any names (e.g. filter weights, BatchNormalization weights, and other parameters), whereas I am only interested in the CNN filter weights.
It would be great if I could implement something like this one in TensorFlow 2.
How can I display a histogram of the filter weights using TensorFlow?
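
To make the goal concrete, here is a minimal sketch of one possible approach, not a confirmed solution. It assumes the convolutions are top-level keras.layers.Conv2D layers of the model (nested sub-models are not traversed), and the helper names conv_kernels and make_conv_hist_logger are hypothetical. Instead of model.get_weights(), it walks model.layers, keeps only the Conv2D layers, and logs each kernel under the layer's name with step=epoch.

```python
import tensorflow as tf
from tensorflow import keras


def conv_kernels(model):
    """Return (layer name, kernel array) pairs for the top-level Conv2D layers."""
    return [
        # For Conv2D, get_weights()[0] is the kernel; the bias (if any) follows.
        (layer.name, layer.get_weights()[0])
        for layer in model.layers
        if isinstance(layer, keras.layers.Conv2D)
    ]


def make_conv_hist_logger(model, writer):
    """Build an on_epoch_end function that logs only the Conv2D kernels."""
    def on_epoch_end(epoch, logs=None):
        with writer.as_default():
            for name, kernel in conv_kernels(model):
                # One histogram per layer; the epoch is used as the step axis.
                tf.summary.histogram(name + "/kernel", kernel, step=epoch)
            writer.flush()
    return on_epoch_end
```

With the names from the question, this would be wired up as keras.callbacks.LambdaCallback(on_epoch_end=make_conv_hist_logger(baseline_model, file_writer)). Using a fixed tag per layer and passing the epoch as the step lets TensorBoard show how each layer's histogram evolves over training.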