
This question is similar to "Tensorflow Keras modify model variable from callback". I am unable to get the solution there to work (perhaps there have been changes in TensorFlow 2.x since it was posted).

Below is demo code. I apologise if there is a typo.

I want to use a callback to update a non-trainable variable (weighted_add_layer.weight) that affects the output of the layer.

I have tried many variants, such as putting tf.keras.backend.set_value(weighted_add_layer.weight, value) in the layer's update method.

In all cases, after the model is compiled, fit uses the value of weighted_add_layer.weight at the time of compilation and does not update the value later.

import tensorflow as tf

tfk = tf.keras
tfkl = tf.keras.layers
losses = tf.keras.losses


class WeightedAddLayer(tf.keras.layers.Layer):
    def __init__(self, weight=0.00, *args, **kwargs):
        super(WeightedAddLayer, self).__init__(*args, **kwargs)
        # non-trainable variable that I want to update from a callback
        self.weight = tf.Variable(weight, trainable=False)

    def add(self, inputA, inputB):
        return (self.weight * inputA + self.weight * inputB)

    def update(self, weight):
        tf.keras.backend.set_value(self.weight, weight)
        
input_A = tfkl.Input(
    shape=(32,),
    batch_size=32,
)

input_B = tfkl.Input(
    shape=(32,),
    batch_size=32,
)

weighted_add_layer = WeightedAddLayer()

output = weighted_add_layer.add(input_A, input_B)

model = tfk.Model(
    inputs=[input_A, input_B],
    outputs=[output],
)
model.compile(
    optimizer='adam', loss=losses.MeanSquaredError()
)

# Custom callback function
def update_fun(epoch, steps=50):
    weighted_add_layer.update(
      tf.clip_by_value(
          epoch / steps,
          clip_value_min=tf.constant(0.0),
          clip_value_max=tf.constant(1.0),)
    )
    

# Custom callback
update_callback = tfk.callbacks.LambdaCallback(
    on_epoch_begin=lambda epoch, logs: update_fun(epoch)
)

# train model
history = model.fit(
    x=train_data,
    epochs=EPOCHS,
    validation_data=valid_data,
    callbacks=[update_callback],
)

Any suggestions? Thanks much!

  • I think it is possible, check this https://stackoverflow.com/a/75435498/9215780 – Innat Mar 11 '23 at 17:54
  • Interesting you use subclassed callback in that example whereas I was using a lambda callback. Thanks! I will try with your approach as soon as I get the chance. – Anirban Mukherjee Mar 11 '23 at 18:58

2 Answers

  1. This could be an issue with TensorFlow 2.11.0, my installation, or something else I am missing, but with my code base the lambda callback was extremely unstable: it raised errors constantly and did not do what I wanted. It also led to odd behaviour that looked like a memory leak. The complete model's code is very complex and I don't have the time to debug it, so I am sharing this information with a big FWIW caveat.

  2. The code in "Is there a way to make a layer behave differently during forward pass for model.fit() and model.evaluate() in a customised Keras model?" works. Some pointers:

a. The tf.Variable must sit inside a layer and be non-trainable. I could not get this approach to work with a tf.Variable defined outside a layer. That is not a big deal, as you can always define a trivial layer that only scales an input or does some other simple computation and use that layer to complete the task (a minimal sketch follows point b). In my case, tf.Variables outside a layer were optimised away by the compiler, so there was no way to update them after compilation.

b. assign works well as the update mechanism. I tried other approaches but ended up with assign.
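
To illustrate (a) and (b), here is a minimal sketch (my own illustrative names, not tested) of a trivial layer that only scales its input by a non-trainable tf.Variable and is updated in place with assign after compilation:

import tensorflow as tf

class ScaleLayer(tf.keras.layers.Layer):
    # Trivial layer: multiplies its input by a non-trainable variable.
    def __init__(self, scale=1.0, **kwargs):
        super().__init__(**kwargs)
        # trainable=False keeps the optimizer away; because the variable
        # lives inside a layer it is still reachable after compilation
        self.scale = tf.Variable(scale, trainable=False, dtype=tf.float32)

    def call(self, inputs):
        return self.scale * inputs

scale_layer = ScaleLayer()
# ... build and compile a model that uses scale_layer ...
# later, e.g. from a callback, update the variable in place:
scale_layer.scale.assign(0.5)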

Here is a callback subclass consistent with the demo code. Note that when using the class you have to pass an instance of it to fit; you cannot pass the class itself. Also note that this is not my real code but something I wrote to match the demo code above. It has not been tested and may contain errors or typos.

class update_callback(tf.keras.callbacks.Callback):
    def on_epoch_begin(self, epoch, logs=None, steps=50):
        update_value = tf.clip_by_value(
            tf.cast((epoch + 1) / steps, dtype=tf.float32),
            clip_value_min=tf.constant(0.0, dtype=tf.float32),
            clip_value_max=tf.constant(1.0, dtype=tf.float32),
        ) # change this to what you want
        weighted_add_layer.weight.assign(update_value) #assign the update
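
For example, against the demo code above, fit would then be called with an instance of the class (again untested):

history = model.fit(
    x=train_data,
    epochs=EPOCHS,
    validation_data=valid_data,
    callbacks=[update_callback()],  # pass an instance, not the class itself
)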

Unfortunately, based on this question, this question, and the related links in them, I don't think it is possible to freeze layers after model.compile(). In your case, you have to save, freeze, and then re-compile the model.
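
A rough sketch of that save / freeze / re-compile pattern (the filename and the choice of layers to freeze are illustrative and depend on your model):

model.save_weights('pre_freeze_checkpoint')  # save the current weights (illustrative path)
for layer in model.layers:
    layer.trainable = False                  # freeze the layers
model.compile(optimizer='adam', loss=losses.MeanSquaredError())  # freezing only takes effect after re-compiling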

Minh-Long Luu