Based on my answer to "Create keras callback to save model predictions and targets for each batch during training", I use the following code:

```python
"""Demonstrate activation histograms."""
import tensorflow as tf
from tensorflow import keras


class ActivationHistogramCallback(keras.callbacks.Callback):
    """Output activation histograms."""

    def __init__(self, layers):
        """Initialize layer data."""
        super().__init__()

        self.layers = layers
        self.batch_layer_outputs = {}
        self.writer = tf.summary.create_file_writer("activations")
        self.step = tf.Variable(0, dtype=tf.int64)

    def set_model(self, _model):
        """Wrap layer calls to access layer activations."""
        for layer in self.layers:
            self.batch_layer_outputs[layer] = tf_nan(layer.output.dtype)

            def outer_call(inputs, layer=layer, layer_call=layer.call):
                outputs = layer_call(inputs)
                self.batch_layer_outputs[layer].assign(outputs)
                return outputs

            layer.call = outer_call

    def on_train_batch_end(self, _batch, _logs=None):
        """Write training batch histograms."""
        with self.writer.as_default():
            for layer, outputs in self.batch_layer_outputs.items():
                if isinstance(layer, keras.layers.InputLayer):
                    continue
                tf.summary.histogram(f"{layer.name}/output", outputs, step=self.step)

        self.step.assign_add(1)


def tf_nan(dtype):
    """Create NaN variable of proper dtype and variable shape for assign()."""
    return tf.Variable(float("nan"), dtype=dtype, shape=tf.TensorShape(None))


def main():
    """Run main."""
    model = keras.Sequential([keras.layers.Dense(1, input_shape=(2,))])
    callback = ActivationHistogramCallback(model.layers)

    model.compile(loss="mse", optimizer="adam")
    model.fit(
        x=tf.transpose(tf.range(7.0) + [[0.2], [0.4]]),
        y=tf.transpose(tf.range(7.0) + 10 + [[0.5]]),
        validation_data=(
            tf.transpose(tf.range(11.0) + 30 + [[0.6], [0.7]]),
            tf.transpose(tf.range(11.0) + 40 + [[0.9]]),
        ),
        shuffle=False,
        batch_size=3,
        epochs=2,
        verbose=0,
        callbacks=[callback],
    )


if __name__ == "__main__":
    main()
```
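One detail worth calling out (my gloss, not part of the original answer's code): `outer_call` binds `layer` and `layer_call` as default arguments so that each wrapper captures the layer of its own loop iteration instead of whichever layer the loop finished on. A minimal standalone sketch of that late-binding pitfall, using hypothetical names:

```python
def make_wrappers(items):
    """Illustrate binding loop variables as defaults (hypothetical helper)."""
    late, bound = [], []
    for item in items:
        late.append(lambda: item)             # late binding: every call sees the last item
        bound.append(lambda item=item: item)  # default argument freezes the current item
    return late, bound


late, bound = make_wrappers(["dense_1", "dense_2"])
print([f() for f in late])   # ['dense_2', 'dense_2']
print([f() for f in bound])  # ['dense_1', 'dense_2']
```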
For the example training with 2 epochs and 3 batches per epoch (of unequal size, because the 7 training samples do not divide evenly into batches of 3), one then sees the expected output: 6 batch histograms with 3, 3, 1, 3, 3, 1 peaks.
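As a quick cross-check of that count (a hypothetical helper, not part of the callback): with 7 samples, a batch size of 3, and no shuffling, Keras forms batches of 3, 3 and 1 samples per epoch, and each sample contributes one value, hence one peak, to the Dense(1) output histogram of its batch.

```python
def batch_sizes(num_samples, batch_size):
    """Sizes of the batches formed from num_samples without shuffling."""
    return [min(batch_size, num_samples - start)
            for start in range(0, num_samples, batch_size)]


# 7 training samples, batch_size=3, 2 epochs -> 6 histograms with these peak counts
print(batch_sizes(7, 3) * 2)  # [3, 3, 1, 3, 3, 1]
```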

With 200 epochs (600 batches), one can also see the training progress in the TensorBoard histograms.
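To inspect the histograms, point TensorBoard at the `activations` directory the callback writes to, e.g. with `tensorboard --logdir activations` from a shell, or programmatically (a sketch assuming the standalone `tensorboard` package and its `program.TensorBoard` launcher are available):

```python
from tensorboard import program

tb = program.TensorBoard()
tb.configure(argv=[None, "--logdir", "activations"])
print(f"TensorBoard listening on {tb.launch()}")
```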
