I've run into a problem running tests for a FastAPI + SQLAlchemy app against PostgreSQL: the test run produces lots of errors, although the same tests work well on SQLite. I've created a repo with a minimal (MVP) app and Pytest tests that run under Docker Compose.
The basic error is `sqlalchemy.exc.InterfaceError('cannot perform operation: another operation is in progress')`. This may be related to the app/DB initialization, though I checked that all the operations are performed sequentially. I also tried using a single instance of `TestClient` for all the tests, but got no better results. I hope to find a solution: a correct way of testing such apps.
Here are the most important parts of the code:
app.py:
# Application instance that the event handler and routes below attach to.
app = FastAPI()

# Named constants looked up by request "code"; filled in by the startup handler.
some_items = {}
@app.on_event("startup")
async def startup():
    """Create the database tables, then load the known constants into ``some_items``."""
    await create_database()
    # Extract some data from env, local files, or S3
    some_items.update({"pi": 3.1415926535, "eu": 2.7182818284})
@app.post("/{name}")
async def create_obj(name: str, request: Request):
    """Store the posted JSON body as a new object under ``name``.

    The body's "code" entry must be a key of ``some_items``; otherwise an
    empty 404 response is returned and nothing is stored. On success the
    matching constant is added to the body under "value", the object is
    persisted, and the enriched body is echoed back with status 200.
    """
    payload = await request.json()
    code = payload.get("code")

    # Unknown (or missing) code: reject before touching the database.
    if code not in some_items:
        return JSONResponse(status_code=404, content={})

    payload["value"] = some_items[code]
    async with async_session() as session, session.begin():
        await create_object(session, name, payload)
    return JSONResponse(status_code=200, content=payload)
@app.get("/{name}")
async def get_connected_register(name: str):
    """Return all stored objects called ``name`` as a list of flat dicts.

    Each entry carries the row's "id" and "name" plus every key of the
    row's ``data`` mapping merged in at the top level.
    """
    async with async_session() as session, session.begin():
        rows = await get_objects(session, name)

    return [{"id": row.id, "name": row.name, **row.data} for row in rows]
tests.py:
@pytest.fixture(scope="module")
def event_loop():
    """Provide one dedicated asyncio event loop per test module.

    A fresh loop is created instead of fetching one with
    ``asyncio.get_event_loop()``: that call hands back the shared default
    loop, and closing it during teardown would leave the default loop
    closed for anything that runs after this module.
    """
    loop = asyncio.new_event_loop()
    yield loop
    # Only the loop created above is closed; the default loop is untouched.
    loop.close()
@pytest_asyncio.fixture(scope="module")
async def get_db():
    """Reset the database once per test module: drop all tables, then recreate them.

    The fixture is registered through ``pytest_asyncio.fixture`` alone. The
    ``pytest.mark.asyncio`` mark is meant for test functions; applied to a
    fixture it has no effect, so it is not used here.
    """
    await delete_database()
    await create_database()
@pytest.mark.parametrize("test_case", test_cases_post)
def test_post(get_db, test_case):
    """POST each case's payload to ``/{name}`` and check the returned status code."""
    # The TestClient instance is itself the context manager; it must not be
    # called (``TestClient(app)()`` raises TypeError before any request runs).
    with TestClient(app) as client:
        response = client.post(f"/{test_case['name']}", json=test_case["data"])
        assert response.status_code == test_case["res"]
@pytest.mark.parametrize("test_case", test_cases_get)
def test_get(get_db, test_case):
    """GET ``/{name}`` for each case and check how many objects come back."""
    # The TestClient instance is itself the context manager; it must not be
    # called (``TestClient(app)()`` raises TypeError before any request runs).
    with TestClient(app) as client:
        response = client.get(f"/{test_case['name']}")
        assert len(response.json()) == test_case["count"]
db.py:
# Connection URL: taken from the environment, falling back to a local SQLite file.
DATABASE_URL = environ.get(
    "DATABASE_URL",
    "sqlite+aiosqlite:///./test.db",
)

# Single async engine shared by the whole module.
engine = create_async_engine(DATABASE_URL, echo=True, future=True)

# Factory for AsyncSession objects bound to ``engine``.
async_session = sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
)

# Base class that the ORM models below inherit from.
Base = declarative_base()
async def delete_database():
    """Drop every table registered on ``Base.metadata`` inside one engine transaction."""
    drop_all_tables = Base.metadata.drop_all
    async with engine.begin() as connection:
        # drop_all is a synchronous call, so it goes through run_sync.
        await connection.run_sync(drop_all_tables)
async def create_database():
    """Create every table registered on ``Base.metadata`` inside one engine transaction."""
    create_all_tables = Base.metadata.create_all
    async with engine.begin() as connection:
        # create_all is a synchronous call, so it goes through run_sync.
        await connection.run_sync(create_all_tables)
class Model(Base):
    """ORM mapping for the ``smth`` table: a named record carrying a JSON payload."""

    __tablename__ = "smth"

    # ``Index`` takes the index NAME as its first positional argument and the
    # indexed columns after it. The index is declared in ``__table_args__`` so
    # that it is attached to the table, is named "idx_main", and covers
    # (name, id), which is the lookup-and-order pattern used by ``get_objects``.
    __table_args__ = (Index("idx_main", "name", "id"),)

    # Integer primary key.
    id = Column(Integer, primary_key=True)
    # Required name the object is stored and queried under.
    name = Column(String, nullable=False)
    # Required JSON payload.
    data = Column(JSON, nullable=False)
async def create_object(db: Session, name: str, data: dict):
    """Add a new ``Model`` row with the given name and payload, then flush it.

    ``db.flush()`` is awaited, so ``db`` is used as an async session here.
    Committing is left to the caller.
    """
    record = Model(name=name, data=data)
    db.add(record)
    await db.flush()
async def get_objects(db: Session, name: str):
    """Return all ``Model`` rows whose name equals ``name``, ordered by id.

    ``db.execute()`` is awaited, so ``db`` is used as an async session here.
    """
    statement = select(Model).where(Model.name == name).order_by(Model.id)
    rows = await db.execute(statement)
    return rows.scalars().all()