
I have a microservice (FastAPI) that serves predictions from a model that was pre-trained in a separate repo. Ideally, this microservice has only one job: given an input X, a pre-trained model, and a pickled my_transformer, serve predictions y. The microservice shouldn't know or care what the model or my_transformer are or how they work. In the training repo, the transformer is defined and applied to the target like this:

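import numpy as np
import sklearn.preprocessing

# func maps y in (0, 100] to -log(y / 100); inverse_func maps it back.
my_transformer = sklearn.preprocessing.FunctionTransformer(
    func=lambda y: -np.log(y / 100),
    inverse_func=lambda y_transformed: np.exp(-y_transformed) * 100,
)

# X, y and model come from the training code (not shown here).
y_transformed = my_transformer.transform(y)
model.fit(X, y_transformed)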
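For reference, func computes t = -ln(y / 100), which maps y in (0, 100] onto t >= 0. inverse_func computes 100 * e^(-t), which recovers y. Because the model is fit on the transformed target, the service needs inverse_transform to bring predictions back to the original scale. The sketch below is a scalar, stdlib-only stand-in that uses math.log/math.exp in place of np.log/np.exp (the function names are hypothetical) and checks that round trip:

```python
import math


# Scalar stand-ins (hypothetical names) for the question's func and inverse_func,
# using math instead of numpy so the check runs with the standard library only.
def func(y: float) -> float:
    # y in (0, 100] -> t = -log(y / 100), which is >= 0
    return -math.log(y / 100)


def inverse_func(t: float) -> float:
    # t -> 100 * exp(-t), which undoes func
    return math.exp(-t) * 100


# Round trip: inverse_func(func(y)) should give back y (up to float error).
for y in (0.5, 25.0, 50.0, 99.9):
    assert math.isclose(inverse_func(func(y)), y)
```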

The API should only need to load the transformer (with dill, because lambda functions can't be pickled with pickle or joblib) and call inverse_transform. For the most part this works as expected, until I start using FastAPI and/or uvicorn. Here is the error and stack trace I'm getting:

Traceback (most recent call last):
  File "/Users/[me]/.local/share/virtualenvs/[my_venv]/lib/python3.9/site-packages/uvicorn/protocols/http/h11_impl.py", line 428, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
  File "/Users/[me]/.local/share/virtualenvs/[my_venv]/lib/python3.9/site-packages/uvicorn/middleware/proxy_headers.py", line 78, in __call__
    return await self.app(scope, receive, send)
  File "/Users/[me]/.local/share/virtualenvs/[my_venv]/lib/python3.9/site-packages/fastapi/applications.py", line 276, in __call__
    await super().__call__(scope, receive, send)
  File "/Users/[me]/.local/share/virtualenvs/[my_venv]/lib/python3.9/site-packages/starlette/applications.py", line 122, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/Users/[me]/.local/share/virtualenvs/[my_venv]/lib/python3.9/site-packages/starlette/middleware/errors.py", line 184, in __call__
    raise exc
  File "/Users/[me]/.local/share/virtualenvs/[my_venv]/lib/python3.9/site-packages/starlette/middleware/errors.py", line 162, in __call__
    await self.app(scope, receive, _send)
  File "/Users/[me]/.local/share/virtualenvs/[my_venv]/lib/python3.9/site-packages/starlette/middleware/exceptions.py", line 79, in __call__
    raise exc
  File "/Users/[me]/.local/share/virtualenvs/[my_venv]/lib/python3.9/site-packages/starlette/middleware/exceptions.py", line 68, in __call__
    await self.app(scope, receive, sender)
  File "/Users/[me]/.local/share/virtualenvs/[my_venv]/lib/python3.9/site-packages/fastapi/middleware/asyncexitstack.py", line 21, in __call__
    raise e
  File "/Users/[me]/.local/share/virtualenvs/[my_venv]/lib/python3.9/site-packages/fastapi/middleware/asyncexitstack.py", line 18, in __call__
    await self.app(scope, receive, send)
  File "/Users/[me]/.local/share/virtualenvs/[my_venv]/lib/python3.9/site-packages/starlette/routing.py", line 718, in __call__
    await route.handle(scope, receive, send)
  File "/Users/[me]/.local/share/virtualenvs/[my_venv]/lib/python3.9/site-packages/starlette/routing.py", line 276, in handle
    await self.app(scope, receive, send)
  File "/Users/[me]/.local/share/virtualenvs/[my_venv]/lib/python3.9/site-packages/starlette/routing.py", line 66, in app
    response = await func(request)
  File "/Users/[me]/.local/share/virtualenvs/[my_venv]/lib/python3.9/site-packages/fastapi/routing.py", line 237, in app
    raw_response = await run_endpoint_function(
  File "/Users/[me]/.local/share/virtualenvs/[my_venv]/lib/python3.9/site-packages/fastapi/routing.py", line 165, in run_endpoint_function
    return await run_in_threadpool(dependant.call, **values)
  File "/Users/[me]/.local/share/virtualenvs/[my_venv]/lib/python3.9/site-packages/starlette/concurrency.py", line 41, in run_in_threadpool
    return await anyio.to_thread.run_sync(func, *args)
  File "/Users/[me]/.local/share/virtualenvs/[my_venv]/lib/python3.9/site-packages/anyio/to_thread.py", line 33, in run_sync
    return await get_asynclib().run_sync_in_worker_thread(
  File "/Users/[me]/.local/share/virtualenvs/[my_venv]/lib/python3.9/site-packages/anyio/_backends/_asyncio.py", line 877, in run_sync_in_worker_thread
    return await future
  File "/Users/[me]/.local/share/virtualenvs/[my_venv]/lib/python3.9/site-packages/anyio/_backends/_asyncio.py", line 807, in run
    result = context.run(func, *args)
  File "/Users/[me]/repos/perpetua1/kpi_predictor/basic_app.py", line 23, in root
    my_transformer.transform(arr),
  File "/Users/[me]/.local/share/virtualenvs/[my_venv]/lib/python3.9/site-packages/sklearn/utils/_set_output.py", line 140, in wrapped
    data_to_wrap = f(self, X, *args, **kwargs)
  File "/Users/[me]/.local/share/virtualenvs/[my_venv]/lib/python3.9/site-packages/sklearn/preprocessing/_function_transformer.py", line 238, in transform
    return self._transform(X, func=self.func, kw_args=self.kw_args)
  File "/Users/[me]/.local/share/virtualenvs/[my_venv]/lib/python3.9/site-packages/sklearn/preprocessing/_function_transformer.py", line 310, in _transform
    return func(X, **(kw_args if kw_args else {}))
  File "/var/folders/td/_z1v0f3x703gj0wdnfms4h7c0000gp/T/ipykernel_9600/434310704.py", line 4, in <lambda>
NameError: name 'np' is not defined
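The last frame points at a file under ipykernel_9600, so the lambda was compiled in a Jupyter kernel. The NameError is raised when its body looks up np among its globals. Python resolves a function's global names at call time through the function's __globals__ dict. A deserializer that rebuilds a lambda by value and attaches it to a namespace without np therefore produces exactly this error. Which namespace dill attaches it to is an assumption here. By default it appears to be the loading process's __main__, which is basic_app.py under uvicorn.run(app) but likely some other module inside the reload worker. The stdlib-only sketch below reproduces just the mechanism, with math standing in for np and rebind as a hypothetical helper:

```python
import math
import types

# Stand-in for the pickled func: a lambda whose body refers to a module-level
# name (math here, np in the question). Names are illustrative only.
neg_log = lambda y: -math.log(y / 100)


def rebind(func, namespace):
    """Hypothetical helper: rebuild `func` from the same code object but with a
    different globals dict, similar in spirit to a deserializer reattaching a
    by-value function to some namespace in the loading process."""
    return types.FunctionType(
        func.__code__, namespace, func.__name__, func.__defaults__, func.__closure__
    )


# Globals that contain `math`: the lookup in the lambda's body succeeds.
works = rebind(neg_log, {"math": math})
print(works(50.0))  # -log(0.5), about 0.693

# Globals without `math`: same code, but the global lookup fails at call time.
broken = rebind(neg_log, {})
try:
    broken(50.0)
    error_message = None
except NameError as exc:
    error_message = str(exc)
print(error_message)
```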

I've narrowed it down to something related to how uvicorn is run, but I don't understand why it breaks. Switching from `uvicorn.run("basic_app:app", reload=True)` to `uvicorn.run(app)` (see the end of the app below) fixes this simple case but not my more complicated project. I'm hoping that understanding why it breaks will show me how to solve the more complicated case.

A basic version of my app looks like:

import dill
import numpy as np
import pandas as pd
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel


class ReturnValue(BaseModel):
    value: float


app = FastAPI()


@app.post("/", response_model=list[ReturnValue])
def root():
    with open("my_transformer.dill", "rb") as io:
        my_transformer = dill.load(io)

    arr = np.random.uniform(0, 100, 10)

    return pd.DataFrame(
        my_transformer.transform(arr),
        columns=["value"],
    ).to_dict(orient="records")


if __name__ == "__main__":
    # uvicorn.run("basic_app:app", reload=True) # NameError
    uvicorn.run(app) # runs as expected for this simple case, doesn't work for the more complicated app
bernie
  • Try using cloudpickle instead of dill. It is better at managing global references such as np in your case. https://stackoverflow.com/questions/32757656/what-are-the-pitfalls-of-using-dill-to-serialise-scikit-learn-statsmodels-models#:~:text=dill%20can%20pickle%20more%20types,cloudpickle%20physically%20stores%20the%20dependencies. – ilmarinen May 31 '23 at 21:29
  • I'm the `dill` author. I'd disagree... and you'll note that my answer in the above link recommends using `dill.settings['recurse'] = True`, which essentially uses the same global reference pickling strategy as `cloudpickle` but within `dill`, so you can serialize a broader range of objects. I do agree that it's likely a global reference issue you are experiencing. Best is to try `cloudpickle` and also try changing the `dill` settings, and go with whatever works. – Mike McKerns Jun 01 '23 at 12:24
  • `cloudpickle` did the trick but I'll keep that in mind. Thanks for the response! I've used `dill` in the past and this was my first time ever having issues so I'm happy to hear it's likely user error. – bernie Jun 01 '23 at 23:36
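The comments above suggest a training-side change. Here is a minimal sketch of it, assuming dill and numpy are installed; the variable names are hypothetical stand-ins for my_transformer's func and inverse_func. Plain pickle rejects the lambda. dill with recurse=True stores the globals the lambda refers to (np here) alongside it. Because recurse is an option on the pickler, the same keyword should apply when dumping my_transformer itself. The other route is cloudpickle.dump(my_transformer, f), which is what ended up working here.

```python
import pickle

import dill
import numpy as np

# Stand-ins (hypothetical names) for the transformer's func / inverse_func.
func = lambda y: -np.log(y / 100)
inverse_func = lambda y_transformed: np.exp(-y_transformed) * 100

# Plain pickle stores functions by reference (module + qualified name), and a
# lambda has no importable name, so this is expected to fail.
try:
    pickle.dumps(inverse_func)
    plain_pickle_failed = False
except (pickle.PicklingError, AttributeError):
    plain_pickle_failed = True

# dill pickles the lambda by value; recurse=True additionally traces the global
# names the function refers to (here: np) and pickles them along with it.
# Setting dill.settings["recurse"] = True would make this the default instead.
payload = dill.dumps(inverse_func, recurse=True)
restored = dill.loads(payload)

# Round trip through the restored inverse_func.
values = np.array([25.0, 50.0, 75.0])
roundtrip = restored(func(values))
print(plain_pickle_failed, roundtrip)
```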
