This question is similar to "Tensorflow Keras modify model variable from callback". I am unable to get the solution there to work (maybe there have been changes in TensorFlow 2.x since the solution was posted).
Below is demo code. I apologise if there is a typo.
I want to use a callback to update a non-trainable variable (weighted_add_layer.weight) that affects the output of the layer.
I have tried many variants, such as putting tf.keras.backend.set_value(weighted_add_layer.weight, value) in the update function.
In all cases, after the model is compiled, fit uses the value of weighted_add_layer.weight at the time of compilation and does not update the value later.
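For concreteness, one such variant is a plain Callback subclass instead of the LambdaCallback in the demo code below (just a sketch; it ramps the weight the same way):

import tensorflow as tf

class WeightUpdateCallback(tf.keras.callbacks.Callback):
    def __init__(self, layer, steps=50):
        super().__init__()
        self.layer = layer   # the WeightedAddLayer instance from the demo below
        self.steps = steps

    def on_epoch_begin(self, epoch, logs=None):
        # Same ramp as update_fun below: 0 -> 1 over `steps` epochs
        tf.keras.backend.set_value(
            self.layer.weight, min(epoch / self.steps, 1.0)
        )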
import tensorflow as tf

tfk = tf.keras
tfkl = tf.keras.layers
losses = tf.keras.losses

class WeightedAddLayer(tf.keras.layers.Layer):
    def __init__(self, weight=0.00, *args, **kwargs):
        super(WeightedAddLayer, self).__init__(*args, **kwargs)
        # Non-trainable, so the optimizer never touches it;
        # only the callback should change it.
        self.weight = tf.Variable(weight, trainable=False)

    def add(self, inputA, inputB):
        return self.weight * inputA + self.weight * inputB

    def update(self, weight):
        tf.keras.backend.set_value(self.weight, weight)

input_A = tfkl.Input(shape=(32,), batch_size=32)
input_B = tfkl.Input(shape=(32,), batch_size=32)

weighted_add_layer = WeightedAddLayer()
output = weighted_add_layer.add(input_A, input_B)

model = tfk.Model(
    inputs=[input_A, input_B],
    outputs=[output],
)
model.compile(
    optimizer='adam', loss=losses.MeanSquaredError()
)

# Custom callback function: ramp the weight from 0 to 1 over `steps` epochs
def update_fun(epoch, steps=50):
    weighted_add_layer.update(
        tf.clip_by_value(
            epoch / steps,
            clip_value_min=tf.constant(0.0),
            clip_value_max=tf.constant(1.0),
        )
    )

# Custom callback
update_callback = tfk.callbacks.LambdaCallback(
    on_epoch_begin=lambda epoch, logs: update_fun(epoch)
)

# Train the model (train_data, valid_data, and EPOCHS are defined elsewhere)
history = model.fit(
    x=train_data,
    epochs=EPOCHS,
    validation_data=valid_data,
    callbacks=[update_callback],
)
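To watch what set_value actually does during fit, a debugging callback along these lines can be attached as well (debugging aid only, not part of the real pipeline):

# Debugging aid: print the variable at the start of every epoch so the
# effect of set_value (or the lack of it) is visible during fit.
print_callback = tfk.callbacks.LambdaCallback(
    on_epoch_begin=lambda epoch, logs: print(
        'epoch', epoch, 'weight =', weighted_add_layer.weight.numpy()
    )
)
# e.g. model.fit(..., callbacks=[update_callback, print_callback])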
Any suggestions? Thanks much!