
I have a piece of code (not mine) that defines a non-trainable variable that is used to define another property of the layer, which looks something like

import tensorflow as tf

initial_weight_val = 1.0
w = my_layer.add_weight(name=my_layer.name + '/my_weight', shape=(),
                        initializer=tf.initializers.constant(initial_weight_val),
                        trainable=False)
# Use w to set another parameter of the layer.
my_layer.the_parameter = some_function(w)

Please do not tell me what a non-trainable variable is (of course, I know what it is); this is also discussed in What is the definition of a non-trainable parameter?

However, given that w will not be changed (I think), I don't get why someone would define such a variable rather than simply using the Python variable initial_weight_val directly, especially in TensorFlow 2.0 (which is my case and the only case I am interested in). Of course, one possibility is that the variable might need to become trainable later, but why should one anticipate that anyway?

Can I safely use initial_weight_val to define the_parameter, i.e. pass initial_weight_val to some_function rather than w?

I am concerned with this issue because I cannot save a model that contains such a variable: I get the error "variable is not JSON serializable" (Keras and TF are so buggy, btw!). So I was trying to understand the equivalence between user-defined non-trainable variables and plain Python variables.


1 Answer


You must make sure that this value never changes and that it's a single value.
Then yes, you can use a Python variable (provided a plain Python value is compatible with the function that consumes w).

In that case, you'd put initial_weight_val both in the __init__ and in the get_config methods of the layer so that it is properly saved (a sketch follows the next paragraph).

Now, if the function only accepts tensors, but you're still sure that this value will never change, then you can create w = tf.constant(self.initial_weight_val) inside call. You still keep the value in __init__ and in get_config as a Python variable.
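
For concreteness, here is a minimal sketch of those two options in one custom layer; MyLayer and some_function are hypothetical stand-ins, not the real names from your code:

import tensorflow as tf

def some_function(x):
    # Hypothetical stand-in for the real function in the question.
    return x * 2.0

class MyLayer(tf.keras.layers.Layer):
    def __init__(self, initial_weight_val=1.0, **kwargs):
        super().__init__(**kwargs)
        # Plain Python float: JSON serializable, so saving the model works.
        self.initial_weight_val = initial_weight_val
        self.the_parameter = some_function(initial_weight_val)

    def call(self, inputs):
        # Only needed if the consuming function requires a tensor:
        # wrap the fixed float on the fly.
        w = tf.constant(self.initial_weight_val)
        return inputs * w  # stand-in computation

    def get_config(self):
        config = super().get_config()
        config.update({'initial_weight_val': self.initial_weight_val})
        return config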

Finally, if this value, although non-trainable, does change, or if it's a tensor with many elements, then you'd better let it be a weight. (Non-trainable means "not trained by backpropagation", but still allowed to be updated here and there.)
There should be absolutely no problem saving and loading this weight if you define it correctly, which means inside build, with self.add_weight(...), as shown in https://keras.io/layers/writing-your-own-keras-layers/ .
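
A minimal sketch of that case, following the custom-layer pattern from the Keras docs linked above (the layer name and the use of w in call are illustrative):

import tensorflow as tf

class MyStatefulLayer(tf.keras.layers.Layer):
    def __init__(self, initial_weight_val=1.0, **kwargs):
        super().__init__(**kwargs)
        self.initial_weight_val = initial_weight_val

    def build(self, input_shape):
        # Non-trainable scalar weight: excluded from backpropagation,
        # saved and loaded with the model, and updatable via self.w.assign(...).
        self.w = self.add_weight(
            name='my_weight', shape=(),
            initializer=tf.initializers.constant(self.initial_weight_val),
            trainable=False)
        super().build(input_shape)

    def call(self, inputs):
        return inputs * self.w  # stand-in for the real use of w

    def get_config(self):
        config = super().get_config()
        config.update({'initial_weight_val': self.initial_weight_val})
        return config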


A cool Keras example that uses non-trainable but updatable weights is the BatchNormalization layer. Its moving mean and variance are updated on every training pass, but not via backpropagation (hence trainable=False).
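
You can see this in TF 2.x by inspecting the layer's weight lists:

import tensorflow as tf

bn = tf.keras.layers.BatchNormalization()
bn.build((None, 4))  # build the layer for inputs with 4 features
print([w.name for w in bn.trainable_weights])      # gamma and beta
print([w.name for w in bn.non_trainable_weights])  # moving_mean and moving_variance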
