I'm implementing a custom Layer with the Keras API (working with TF2.0-beta). I want to use the epoch number in my calculation (i.e., in the call() method) in order to decay a parameter over time.
I'm used to tf.get_global_step(), but I understand that TF deprecated all global scopes - and for good reason.
If I had the model instance, I could use model.optimizer.iterations, but I'm not sure how to get a reference to the parent model from inside a Layer implementation.
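For context, this is what I mean by reading the optimizer's step counter when the model is available (a minimal sketch; the toy Sequential model and Adam optimizer are just placeholders):

```python
import numpy as np
import tensorflow as tf

# Toy model purely for illustration; any compiled model would do.
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)),
                             tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

# optimizer.iterations is a variable that increments once per gradient update.
print(int(model.optimizer.iterations))  # 0 before any training

model.train_on_batch(np.ones((2, 4)), np.ones((2, 1)))
print(int(model.optimizer.iterations))  # 1 after a single batch
```

The counter counts gradient updates, not epochs, but an epoch-level decay could be derived from it given the steps-per-epoch.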
Is there any way to do that, or is the only option to have the layer expose a Callback that updates the parameter I want to decay? Any other ideas? Ideally the solution wouldn't make the user of the layer aware of this inner detail - that's why I don't like the Callback approach (the user has to add it to the model).
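For the record, the Callback approach I'm referring to would look roughly like this (a sketch under my own naming; EpochAwareLayer and EpochUpdater are hypothetical):

```python
import tensorflow as tf

class EpochAwareLayer(tf.keras.layers.Layer):
    """Layer whose decayed rate depends on an externally updated epoch counter."""

    def __init__(self, initial_rate=1.0, decay=0.1, **kwargs):
        super().__init__(**kwargs)
        self.initial_rate = initial_rate
        self.decay = decay

    def build(self, input_shape):
        # Non-trainable variable holding the current epoch number.
        self.epoch = self.add_weight(
            name="epoch", shape=(), dtype=tf.float32,
            initializer="zeros", trainable=False)
        super().build(input_shape)

    def call(self, inputs):
        # Inverse-time decay driven by the epoch variable.
        rate = self.initial_rate / (1.0 + self.decay * self.epoch)
        return inputs * rate


class EpochUpdater(tf.keras.callbacks.Callback):
    """Callback that pushes the epoch number into every EpochAwareLayer."""

    def on_epoch_begin(self, epoch, logs=None):
        for layer in self.model.layers:
            if isinstance(layer, EpochAwareLayer):
                layer.epoch.assign(float(epoch))
```

The user would still have to pass EpochUpdater() in model.fit(..., callbacks=[...]), which is exactly the leakage I'd like to avoid.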