Just for completeness, I am adding a bit on top of benjaminplanche's answer. If your custom layer AttentionLayer has any initialization parameters that configure its behaviour, you need to implement the get_config method of the class. Otherwise the model will fail to load. I am writing this because I had a lot of trouble figuring out how to load custom layers with arguments, so I'll leave it here.
For example, a dummy implementation of your layer:
```python
from tensorflow.keras.layers import Layer  # assuming TensorFlow's bundled Keras

class AttentionLayer(Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        super().build(input_shape)

    def call(self, x):
        # The actual attention computation would go here; this dummy returns its input.
        return x

    def compute_output_shape(self, input_shape):
        return input_shape
```
This will load with any of the approaches detailed in benjaminplanche's answer, i.e. using custom_objects={'AttentionLayer': AttentionLayer}. However, if your layer takes constructor arguments, loading will fail.
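Imagine the __init__ method of your class has two parameters:
```python
from tensorflow.keras.layers import Layer  # assuming TensorFlow's bundled Keras

class AttentionLayer(Layer):
    def __init__(self, param1, param2, **kwargs):
        self.param1 = param1
        self.param2 = param2
        super().__init__(**kwargs)
```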
Then, when you load it with:
```python
from tensorflow.keras.models import load_model  # assuming TensorFlow's bundled Keras

model = load_model('my_model.h5', custom_objects={'AttentionLayer': AttentionLayer})
```
It would throw this error:
```
Traceback (most recent call last):
  File "/path/to/file/cstm_layer.py", line 62, in <module>
    h = AttentionLayer()(x)
TypeError: __init__() missing 2 required positional arguments: 'param1' and 'param2'
```
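To see why this happens, here is a heavily simplified, pure-Python model of the save/load round trip. It assumes, loosely following Keras, that saving stores the class name plus the result of get_config(), and that loading looks the class up in custom_objects and calls from_config, which passes the saved config straight to the constructor. ToyLayer, save_layer and load_layer are hypothetical names for illustration only, not Keras APIs, and no TensorFlow is needed to run it.

```python
import json

# Hypothetical, heavily simplified stand-in for a Keras-style layer base class.
class ToyLayer:
    def __init__(self, name=None):
        self.name = name or type(self).__name__.lower()

    def get_config(self):
        # The base class only knows about its own constructor arguments.
        return {"name": self.name}

    @classmethod
    def from_config(cls, config):
        # Rebuild the layer by passing the saved config back to the constructor.
        return cls(**config)


class AttentionLayer(ToyLayer):
    def __init__(self, param1, param2, **kwargs):
        self.param1 = param1
        self.param2 = param2
        super().__init__(**kwargs)
    # No get_config override: param1 and param2 are never written out.


def save_layer(layer):
    # Hypothetical "save": store the class name plus its config as JSON.
    return json.dumps({"class_name": type(layer).__name__, "config": layer.get_config()})


def load_layer(saved, custom_objects):
    # Hypothetical "load": look the class up by name, then call from_config.
    data = json.loads(saved)
    cls = custom_objects[data["class_name"]]
    return cls.from_config(data["config"])


saved = save_layer(AttentionLayer(param1=8, param2="dot"))
load_error = None
try:
    load_layer(saved, custom_objects={"AttentionLayer": AttentionLayer})
except TypeError as err:
    # The saved config only contains "name", so the constructor is missing its arguments.
    load_error = err
    print("Loading failed:", err)
```

The saved config only holds the base-class "name" entry, so rebuilding the layer amounts to AttentionLayer(name=...), which is exactly the "missing 2 required positional arguments" failure.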
To solve it, you need to implement the get_config method in your custom layer class. An example:
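```python
from tensorflow.keras.layers import Layer  # assuming TensorFlow's bundled Keras

class AttentionLayer(Layer):
    def __init__(self, param1, param2, **kwargs):
        self.param1 = param1
        self.param2 = param2
        super().__init__(**kwargs)

    # ...

    def get_config(self):
        # For serialization with 'custom_objects': start from the base config
        # (name, trainable, dtype, ...) and add this layer's own constructor arguments.
        config = super().get_config()
        config['param1'] = self.param1
        config['param2'] = self.param2
        return config
```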
So when you save the model, the saving routine calls get_config and serializes the constructor arguments of your custom layer, i.e. self.param1 and self.param2. When you load it, the loader passes that dictionary to the layer's from_config, which by default calls the constructor with it (roughly AttentionLayer(**config)), so the layer is rebuilt with the same arguments.
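The effect of the get_config override can be checked with a small pure-Python model of the round trip. It assumes, loosely following Keras, that saving stores the class name plus get_config() and that loading calls from_config, which passes the config to the constructor. ToyLayer, save_layer and load_layer are hypothetical illustration names, not Keras APIs; only the get_config pattern mirrors the example above.

```python
import json

# Hypothetical, heavily simplified stand-in for a Keras-style layer base class.
class ToyLayer:
    def __init__(self, name=None):
        self.name = name or type(self).__name__.lower()

    def get_config(self):
        return {"name": self.name}

    @classmethod
    def from_config(cls, config):
        # Default reconstruction: feed the saved config back to the constructor.
        return cls(**config)


class AttentionLayer(ToyLayer):
    def __init__(self, param1, param2, **kwargs):
        self.param1 = param1
        self.param2 = param2
        super().__init__(**kwargs)

    def get_config(self):
        # Same pattern as above: base config plus this layer's own arguments.
        config = super().get_config()
        config["param1"] = self.param1
        config["param2"] = self.param2
        return config


def save_layer(layer):
    # Hypothetical "save": class name plus config, serialized as JSON.
    return json.dumps({"class_name": type(layer).__name__, "config": layer.get_config()})


def load_layer(saved, custom_objects):
    # Hypothetical "load": resolve the class via custom_objects, then rebuild it.
    data = json.loads(saved)
    cls = custom_objects[data["class_name"]]
    return cls.from_config(data["config"])


original = AttentionLayer(param1=8, param2="dot", name="attention")
restored = load_layer(save_layer(original), custom_objects={"AttentionLayer": AttentionLayer})
print(restored.param1, restored.param2, restored.name)  # prints: 8 dot attention
```

Because the config is written out as JSON in this sketch, the values you put into get_config should be plain serializable types (numbers, strings, lists, dicts) rather than arbitrary Python objects.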