I have a set of pre-processing stages in sklearn Pipeline
and an estimator which is a KerasClassifier
(from tensorflow.keras.wrappers.scikit_learn import KerasClassifier
).
My overall goal is to tune and log the whole sklearn pipeline in mlflow
(in databricks evn). I get a confusing type error which I can't figure out how to reslove:
TypeError: can't pickle _thread.RLock objects
I have the following code (without tuning stage) which returns the above error:
conda_env = _mlflow_conda_env(
additional_conda_deps=None,
additional_pip_deps=[
"cloudpickle=={}".format(cloudpickle.__version__),
"scikit-learn=={}".format(sklearn.__version__),
"numpy=={}".format(np.__version__),
"tensorflow=={}".format(tf.__version__),
],
additional_conda_channels=None,
)
search_space = {
"estimator__dense_l1": 20,
"estimator__dense_l2": 20,
"estimator__learning_rate": 0.1,
"estimator__optimizer": "Adam",
}
def create_model(n):
model = Sequential()
model.add(Dense(int(n["estimator__dense_l1"]), activation="relu"))
model.add(Dense(int(n["estimator__dense_l2"]), activation="relu"))
model.add(Dense(1, activation="sigmoid"))
model.compile(
loss="binary_crossentropy",
optimizer=n["estimator__optimizer"],
metrics=["accuracy"],
)
return model
mlflow.sklearn.autolog()
with mlflow.start_run(nested=True) as run:
classfier = KerasClassifier(build_fn=create_model, n=search_space)
# fit the pipeline
clf = Pipeline(steps=[("preprocessor", preprocessor),
("estimator", classfier)])
h = clf.fit(
X_train,
y_train.values,
estimator__validation_split=0.2,
estimator__epochs=10,
estimator__verbose=2,
)
# log scores
acc_score = clf.score(X=X_test, y=y_test)
mlflow.log_metric("accuracy", acc_score)
signature = infer_signature(X_test, clf.predict(X_test))
# Log the model with a signature that defines the schema of the model's inputs and outputs.
mlflow.sklearn.log_model(
sk_model=clf, artifact_path="model",
signature=signature,
conda_env=conda_env
)
I also get this warning before the error:
WARNING mlflow.sklearn.utils: Truncated the value of the key `steps`. Truncated value: `[('preprocessor', ColumnTransformer(n_jobs=None, remainder='drop', sparse_threshold=0.3,
transformer_weights=None,
transformers=[('num',
Pipeline(memory=None,
note the the whole pipeline runs outside mlflow. can someone help?