The Tensorflow 2 documentation states that users can save a Tensorflow Keras Model by calling model.save() with either the "SavedModel" or "h5" format (latest version 2.4.1: https://www.tensorflow.org/guide/keras/save_and_serialize#whole-model_saving_loading). Assuming the "SavedModel" format is used, I am wondering whether it is intended by design to periodically save checkpoints in that format. For example:
import numpy as np
import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense


def get_model() -> Model:
    """
    Define a TF Keras Model and compile it with a loss and an optimizer.
    """
    x_in = Input(shape=(4,), name="input")
    layer1 = Dense(64, name="l1")(x_in)
    layer2 = Dense(64, name="l2")(layer1)
    x_out = Dense(2, name="output")(layer2)
    model = Model(inputs=x_in, outputs=x_out, name="m")
    model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=tf.keras.optimizers.Adam(),
        metrics=[])
    return model


def train_step(input_data: np.ndarray, label_data: np.ndarray, model: Model) -> None:
    """
    Perform one training step for the built Keras model, and periodically save
    a "SavedModel" checkpoint -- is this the intended usage?
    """
    # Perform a training step with the input and label data.
    with tf.GradientTape() as tape:
        # A simple pass to mock the training step.
        pass
    # At the end of each training step, we save the updated model.
    # But... is "SavedModel" the desired format for periodically saving checkpoints?
    model.save("/tmp/saved_model/", save_format="tf")


def main() -> None:
    model = get_model()
    # Train for 100 epochs.
    for epoch in range(100):
        # In a real loop: for inputs, labels in train_data: ...
        train_step(input_data=np.array([1.0]), label_data=np.array([1.0]), model=model)
        print(f"Epoch {epoch}, the training metric is...")


main()
I'm asking because my understanding is that "SavedModel" is designed for saving a model only when it is "ready" to be deployed for inference (i.e. the model is trained well), so users save the "SavedModel" only once (or O(1) times), usually at the very end. One piece of evidence for this is in Tensorflow 1, when using tf.Session directly with a Tensorflow graph: if you periodically save a "SavedModel" like the code below, SavedModelBuilder leaks one "Saver" node into the Tensorflow graph every time the builder is created (a small sketch for checking this follows the snippet):
import tensorflow as tf
from tensorflow import saved_model


def _save_to_saved_model(input_tensor: tf.Tensor, output_tensor: tf.Tensor,
                         tf_session: tf.Session, saved_model_path: str) -> None:
    """
    Save the Tensorflow graph to a Tensorflow saved model.
    """
    # Create the saved model builder.
    builder = saved_model.builder.SavedModelBuilder(saved_model_path)
    # Build the tensor info protos from the tensors.
    tensor_info_obs = saved_model.utils.build_tensor_info(input_tensor)
    tensor_info_output = saved_model.utils.build_tensor_info(output_tensor)
    # Get the default method name.
    method_name = saved_model.signature_constants.PREDICT_METHOD_NAME
    policy_signature = saved_model.signature_def_utils.build_signature_def(
        inputs={"input": tensor_info_obs},
        outputs={"output": tensor_info_output},
        method_name=method_name)
    # Get the signature def map key.
    serving_signature_key = (
        saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
    builder.add_meta_graph_and_variables(
        tf_session, [saved_model.tag_constants.SERVING],
        signature_def_map={serving_signature_key: policy_signature})
    # Save the saved model.
    builder.save()


def main() -> None:
    for _ in range(100):
        # Mock each train step by only saving the "SavedModel".
        _save_to_saved_model(...)


main()
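To make the leak concrete, here is a minimal sketch of how one could observe it (my own helper, not part of the snippet above; it assumes Tensorflow 1.x graph mode and that no saver argument is passed to add_meta_graph_and_variables, in which case the builder creates a new tf.train.Saver internally):

import tensorflow as tf


def count_saver_nodes() -> int:
    """Count the SaveV2 ops that each tf.train.Saver adds to the default graph."""
    return sum(1 for op in tf.get_default_graph().get_operations()
               if op.type == "SaveV2")

# Calling _save_to_saved_model() repeatedly and printing count_saver_nodes()
# after each call shows the count growing by one per SavedModelBuilder,
# because every builder constructs its own Saver under a new "save_N" name scope.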
Another question posted a few years ago seems to mention the same point: How to periodically save tensorflow model using saved_model API?. However, Tensorflow Keras doesn't seem to have the same leaking issue, because Model.save() appears to create a TrackableSaver rather than a Saver, which does not leak saver nodes into the Tensorflow graph. Still, I want to know whether, with a Tensorflow Keras Model, it is intended usage to periodically save checkpoints in the "SavedModel" format.
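For contrast, this is the checkpoint-oriented saving path I am comparing "SavedModel" against -- a minimal sketch using tf.train.Checkpoint / tf.train.CheckpointManager (the directory, the max_to_keep value, and the reuse of get_model() from the first snippet are just placeholders for illustration). Its limitation for my use case is that restoring requires re-building the model from code first, which leads to the NOTE below:

import tensorflow as tf

# Periodic checkpointing with tf.train.Checkpoint / CheckpointManager: only the
# variable values are stored (no serialized graph), so restoring requires
# re-creating the model architecture from code first.
model = get_model()  # the helper from the first snippet
checkpoint = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(checkpoint, directory="/tmp/ckpts", max_to_keep=3)

for epoch in range(100):
    # ... run the actual training step here ...
    manager.save()  # writes a lightweight checkpoint every epoch

# Restoring later: rebuild the model from code, then load the variables into it.
restored = get_model()
tf.train.Checkpoint(model=restored).restore(manager.latest_checkpoint)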
NOTE: the "SavedModel" format is being considered because it appears to be the only Keras model persistence format that allows restoring the model without access to the custom model code.
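To illustrate what I mean in the NOTE: with the "SavedModel" directory written by the first snippet, the model can be restored for inference without importing get_model() or any custom layer code (a minimal sketch; the path simply matches the save call above):

import tensorflow as tf

# Load the "SavedModel" written by model.save("/tmp/saved_model/", save_format="tf")
# in a fresh process that has no access to the model-building code.
restored_model = tf.keras.models.load_model("/tmp/saved_model/")
# Run inference; the model's input shape is (None, 4) per the first snippet.
predictions = restored_model(tf.random.normal([1, 4]))
print(predictions.shape)  # (1, 2)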
Thanks!