I'm building a TFX pipeline and using TensorFlow Serving to serve my model. I save the signature with model.save(...).
So far I have been able to use the transform layer to transform the features before prediction with tf_transform_output.transform_features_layer() (see my code below).
However, I'm wondering how to detect anomalies in the input data. For instance, I don't want to predict for an input value that is too far from the distribution the feature was trained on.
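To make concrete what I mean by "too far from the distribution", here is a minimal pure-Python sketch. It is not TFDV and not part of my pipeline; the function names, the sample values, and the threshold of 3 standard deviations are all hypothetical and only illustrate the kind of check I would like to run on each incoming value.

```python
import statistics


def fit_feature_stats(training_values):
    """Summarize a numeric feature by its training mean and standard deviation."""
    return {
        "mean": statistics.fmean(training_values),
        "stdev": statistics.pstdev(training_values),
    }


def is_anomalous(value, stats, max_z=3.0):
    """Return True if value lies more than max_z standard deviations from the mean."""
    if stats["stdev"] == 0:
        # A constant feature: anything other than the constant is out of range.
        return value != stats["mean"]
    z_score = abs(value - stats["mean"]) / stats["stdev"]
    return z_score > max_z


# Hypothetical training values for a single numeric feature.
age_stats = fit_feature_stats([23, 31, 35, 42, 29, 38, 27, 33])

print(is_anomalous(34, age_stats))   # close to the training values
print(is_anomalous(250, age_stats))  # far outside the training values
```

Ideally I would get this behaviour from the schema and statistics that the pipeline already produces, rather than hand-rolling it per feature.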
The tfdv library offers functions like generate_statistics_from_[csv|dataframe|tfrecord], but I was not able to find a good example of generating statistics for serialized tf.Examples (or for anything that is not saved in a file such as CSV or TFRecord).
I'm aware of the following example in the documentation:
import tensorflow_data_validation as tfdv
import tfx_bsl
import pyarrow as pa
decoder = tfx_bsl.coders.example_coder.ExamplesToRecordBatchDecoder()
example = decoder.DecodeBatch([serialized_tfexample])
options = tfdv.StatsOptions(schema=schema)
anomalies = tfdv.validate_instance(example, options)
But in this example serialized_tfexample is a string, whereas in my code below the argument serialized_tf_examples is a Tensor of strings.
Sorry if this is an obvious question. I spent all day looking for a solution without success. Maybe I'm getting this whole thing wrong, and maybe this is not the right place to put validation. So my more general question is: how do you validate incoming input data before prediction when you serve a model in production that was created through a TFX pipeline? I'm thankful for any pointer in the right direction.
Here is my code to which I want to add validation:
...
tf_transform_output = tft.TFTransformOutput(...)
model.tft_layer = tf_transform_output.transform_features_layer()
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None], dtype=tf.string, name='examples')
])
def serve_tf_examples_fn(serialized_tf_examples):
    #### How can I generate stats and validate serialized_tf_examples? ###
    #### Is this the right place? ###
    feature_spec = tf_transform_output.raw_feature_spec()
    feature_spec.pop(TARGET_LABEL)
    parsed_features = tf.io.parse_example(serialized_tf_examples, feature_spec)
    transformed_features = model.tft_layer(parsed_features)
    return model(transformed_features)
...
model.save(serving_model_dir,
           save_format='tf',
           signatures={
               'serving_default': serve_tf_examples_fn
           })