
I am trying to modify a non-trainable model variable from a callback at the beginning of each epoch. Essentially, I would like a mechanism similar to the learning rate scheduler (which has built-in support in TF) but applicable to an arbitrary model variable. The code below is a minimal example to show the concept. I am trying to modify the decay variable, but it does not work. Apparently the initial value of the variable (1.0) is folded into the graph as a constant and never looked at again as training progresses, even though the variable itself seems to be properly modified (to 0.5) by the callback.

import tensorflow as tf

dense1 = tf.keras.layers.Dense(10)
decay = tf.Variable(1.0, trainable=False)
dense2 = tf.keras.layers.Dense(10)

def epoch_callback(epoch):
    tf.keras.backend.set_value(decay, 0.5)
    #decay.assign(0.5)
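    # Both set_value and assign update the variable on the Python side, but the compiled graph keeps using the initial 1.0.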
    print(tf.keras.backend.get_value(decay))

input = tf.keras.layers.Input((MAX_LENGTH,))
x = dense1(input)

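# Scale the activations by decay; the control dependency does not prevent the initial value from being baked into the graph here.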
with tf.control_dependencies([decay]):
    x = x * decay

prediction = dense2(x)

model = tf.keras.Model(inputs=[input], outputs=[prediction])
model.compile(optimizer=tf.keras.optimizers.Adam(), loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))

callbacks = [tf.keras.callbacks.LambdaCallback(on_epoch_begin = lambda epoch, logs: epoch_callback(epoch))]

model.fit(train_ds, epochs=EPOCHS, verbose=1, callbacks=callbacks, validation_data=eval_ds)

@nbro: Here you go. The code below is what worked for me. I use a teacher-forcing protocol, and the per-epoch decay variable is used to "lower the teacher's voice" as training progresses.

class Teacher(tf.keras.layers.Layer):
    def __init__(self, embedding, name='teacher', **kwargs):
        super().__init__(name=name, **kwargs)
        ...

    def build(self, input_shape):
        ...

    def call(self, inputs, training=None):
        x, y, decay = inputs
        ...
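        # During training, attenuate the teacher signal by the per-epoch decay;
        # at inference time the teacher input is silenced entirely (multiplied by 0).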
        if training:
            y = tf.multiply(y, decay)
        else:
            y = tf.multiply(y, tf.constant(0.0))
        ...
        return x

    def get_config(self):
        return {}

class MyNet(tf.keras.Model):
    def __init__(self, name='mynet', **kwargs):
        super().__init__(name=name, **kwargs)

    def build(self, input_shape):
        ...
        self.teacher = Teacher()
        self.decay = tf.Variable(1.0, trainable=False)
        ...

    def set_decay(self, decay):
        self.decay.assign(decay)

    @tf.function
    def call(self, example, training=None):
        x, y = example
        ...
        x = self.teacher((x, y, self.decay))
        ...
        return x

    def get_config(self):
        return {}

def main():

    train_ds = ...
    eval_ds = ...

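    # The label is passed twice: once inside the model input tuple (used for teacher forcing)
    # and once as the target for the loss.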
    train_ds = train_ds.map(lambda data, label: ((data, label), label), num_parallel_calls=tf.data.experimental.AUTOTUNE)
    eval_ds = eval_ds.map(lambda data, label: ((data, label), label), num_parallel_calls=tf.data.experimental.AUTOTUNE)


    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        the_net = MyNet()
        inputs = tf.keras.layers.Input((MAX_LENGTH,), dtype='int64', name='inputs')
        targets = tf.keras.layers.Input((MAX_LENGTH,), dtype='int64', name='targets')
        prediction = the_net((inputs, targets))
        model = tf.keras.Model(inputs=[inputs, targets], outputs=[prediction])
        model.compile(optimizer=tf.keras.optimizers.Adam(), loss=CosineSimilarity(name='val_loss'))

    def _callback_fun(epoch, start = 0, steps = 8):
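        # Linear schedule: decay starts at 1.0 at epoch `start` and reaches 0.0 after `steps` epochs, clipped to [0, 1].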
        the_net.set_decay(tf.clip_by_value(
            (start + steps - epoch) / steps,
            clip_value_min=tf.constant(0.0),
            clip_value_max=tf.constant(1.0)))

    callbacks = [tf.keras.callbacks.LambdaCallback(on_epoch_begin=lambda epoch, logs: _callback_fun(epoch))]

    model.fit(train_ds, epochs=EPOCHS, verbose=2, callbacks=callbacks, validation_data=eval_ds)

if __name__ == '__main__':
    main()
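
For reference, the core of what made it work is keeping the variable inside a layer (or model) whose call() reads it at run time, so its initial value is no longer folded into the graph as a constant. Below is a minimal, illustrative sketch of just that part using the functional API; the DecayScale name and the MAX_LENGTH value are made up for this sketch, and it is not the exact code I ran.

import tensorflow as tf

MAX_LENGTH = 16  # placeholder value for this sketch

class DecayScale(tf.keras.layers.Layer):
    """Multiplies its input by a non-trainable scalar that a callback can reassign."""
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.decay = tf.Variable(1.0, trainable=False)

    def call(self, inputs):
        return inputs * self.decay

scale = DecayScale()

inputs = tf.keras.layers.Input((MAX_LENGTH,))
x = tf.keras.layers.Dense(10)(inputs)
x = scale(x)
outputs = tf.keras.layers.Dense(10)(x)

model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))

# Set the decay to 0.5 at the start of every epoch; the layer reads the updated value.
callbacks = [tf.keras.callbacks.LambdaCallback(
    on_epoch_begin=lambda epoch, logs: scale.decay.assign(0.5))]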
  • Hi @cpajdel, Can you give more context on what you are trying to do? Do you want the variable to change every epoch? – TF_Support Apr 02 '20 at 10:01
  • Yes, that was the intention: the variable would be changed before each epoch begins and then used by the computation graph. In fact, I got it working once I moved the logic above to a subclassed tf.keras.Model which exposes a method to update the per-epoch member tf.Variable. The method is called by the callback as above. It is not quite clear to me, though, why the sample above would not work. – cpajdel Apr 03 '20 at 13:52
  • @cpajdel Can you provide the complete full solution to this problem? – nbro Apr 20 '20 at 20:16
  • @nbro I edited the example above, hope it helps – cpajdel Apr 23 '20 at 18:37
  • @cpajdel Is there a reason why you used `tf.Variable` rather than `self.add_weight`? – nbro Apr 23 '20 at 18:42
  • @nbro no particular reason, it was the first thing that worked when I was experimenting with the code, so I didn't bother any more; would there be some benefit I missed from using add_weight instead? – cpajdel Apr 23 '20 at 18:58
  • @cpajdel I don't know. I only used `add_weight`, but I am getting problems, so I was wondering why you used `Variable`. Anyway, why did you use `Model` rather than building custom layers and then using the functional or sequential APIs? Any benefit to inheriting from `Model`? – nbro Apr 23 '20 at 19:01
  • @nbro my topology uses two inputs: more specifically, a tuple of (example, label) is used as the model input in addition to the regular label target, so in effect the label is passed twice (see the map applied to the original dataset), and the label is used for teacher forcing. I didn't figure out how to use Sequential in this scenario. As for Model vs. functional, my observation was that the functional version performed worse, although the functional code was more concise and nicer. – cpajdel Apr 23 '20 at 19:10
  • @cpajdel - Is your issue resolved? If yes, can you share how you managed to resolve it, for the benefit of the community? –  May 22 '20 at 10:10

0 Answers