For cleaning up after tests even when they fail (and setting up before tests), pytest provides `pytest.fixture`.
In your case you want to create all tables before each test, and drop them again afterwards. This can be achieved with the following fixture:
```python
@pytest.fixture()
def test_db():
    Base.metadata.create_all(bind=engine)
    yield
    Base.metadata.drop_all(bind=engine)
```
And then use it in your tests like so:
```python
def test_get_empty_todos_list(test_db):
    response = client.get('/todos/')
    assert response.status_code == 200
    assert response.json() == []
```
For each test that has `test_db` in its argument list, pytest first runs `Base.metadata.create_all(bind=engine)`, then yields to the test code, and afterwards makes sure that `Base.metadata.drop_all(bind=engine)` gets run, even when the test fails.
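This setup/yield/teardown behaviour is not specific to databases; any generator-style pytest fixture works the same way. A minimal, self-contained illustration (the names here are made up for demonstration):

```python
import pytest


@pytest.fixture()
def resource():
    print("setup")       # runs before the test body
    yield "some value"   # handed to the test as the fixture argument
    print("teardown")    # runs after the test, even if it raised or failed


def test_uses_resource(resource):
    assert resource == "some value"
```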
The full code:
```python
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from main import app, get_db
from database import Base

SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"

engine = create_engine(
    SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


def override_get_db():
    try:
        db = TestingSessionLocal()
        yield db
    finally:
        db.close()


@pytest.fixture()
def test_db():
    Base.metadata.create_all(bind=engine)
    yield
    Base.metadata.drop_all(bind=engine)


app.dependency_overrides[get_db] = override_get_db

client = TestClient(app)


def test_get_todos(test_db):
    response = client.post("/todos/", json={"text": "some new todo"})
    data1 = response.json()
    response = client.post("/todos/", json={"text": "some even newer todo"})
    data2 = response.json()

    assert data1["user_id"] == data2["user_id"]

    response = client.get("/todos/")
    assert response.status_code == 200
    assert response.json() == [
        {"id": data1["id"], "user_id": data1["user_id"], "text": data1["text"]},
        {"id": data2["id"], "user_id": data2["user_id"], "text": data2["text"]},
    ]


def test_get_empty_todos_list(test_db):
    response = client.get("/todos/")
    assert response.status_code == 200
    assert response.json() == []
```
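If your tests are spread over several files, you don't have to repeat the fixture and the override in each of them: pytest automatically picks up fixtures defined in a `conftest.py` next to your tests. A possible sketch (the file split and the `client` fixture are my suggestion, not part of the code above):

```python
# conftest.py (hypothetical layout; test modules just declare the fixtures they need)
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from main import app, get_db
from database import Base

engine = create_engine(
    "sqlite:///./test.db", connect_args={"check_same_thread": False}
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


def override_get_db():
    try:
        db = TestingSessionLocal()
        yield db
    finally:
        db.close()


app.dependency_overrides[get_db] = override_get_db


@pytest.fixture()
def test_db():
    Base.metadata.create_all(bind=engine)
    yield
    Base.metadata.drop_all(bind=engine)


@pytest.fixture()
def client():
    return TestClient(app)
```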
As your application grows, setting up and tearing down the whole database for each test might get slow.
A solution for that is to only set up the db once and then never actually commit anything to it.
This can be achieved using nested transactions and rollbacks:
```python
import pytest
import sqlalchemy as sa
from fastapi.testclient import TestClient
from sqlalchemy.orm import sessionmaker

from database import Base
from main import app, get_db

SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"

engine = sa.create_engine(
    SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Set up the database once
Base.metadata.drop_all(bind=engine)
Base.metadata.create_all(bind=engine)


# These two event listeners are only needed for sqlite for proper
# SAVEPOINT / nested transaction support. Other databases like postgres
# don't need them.
# From: https://docs.sqlalchemy.org/en/14/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl
@sa.event.listens_for(engine, "connect")
def do_connect(dbapi_connection, connection_record):
    # disable pysqlite's emitting of the BEGIN statement entirely.
    # also stops it from emitting COMMIT before any DDL.
    dbapi_connection.isolation_level = None


@sa.event.listens_for(engine, "begin")
def do_begin(conn):
    # emit our own BEGIN
    conn.exec_driver_sql("BEGIN")


# This fixture is the main difference to before. It creates a nested
# transaction, recreates it when the application code calls session.commit
# and rolls it back at the end.
# Based on: https://docs.sqlalchemy.org/en/14/orm/session_transaction.html#joining-a-session-into-an-external-transaction-such-as-for-test-suites
@pytest.fixture()
def session():
    connection = engine.connect()
    transaction = connection.begin()
    session = TestingSessionLocal(bind=connection)

    # Begin a nested transaction (using SAVEPOINT).
    nested = connection.begin_nested()

    # If the application code calls session.commit, it will end the nested
    # transaction. Need to start a new one when that happens.
    @sa.event.listens_for(session, "after_transaction_end")
    def end_savepoint(session, transaction):
        nonlocal nested
        if not nested.is_active:
            nested = connection.begin_nested()

    yield session

    # Rollback the overall transaction, restoring the state before the test ran.
    session.close()
    transaction.rollback()
    connection.close()


# A fixture for the fastapi test client which depends on the
# previous session fixture. Instead of creating a new session in the
# dependency override as before, it uses the one provided by the
# session fixture.
@pytest.fixture()
def client(session):
    def override_get_db():
        yield session

    app.dependency_overrides[get_db] = override_get_db
    yield TestClient(app)
    del app.dependency_overrides[get_db]


def test_get_empty_todos_list(client):
    response = client.get("/todos/")
    assert response.status_code == 200
    assert response.json() == []
```
Having two fixtures (`session` and `client`) here has an additional advantage: if a test only talks to the API, you don't need to remember to add the db fixture explicitly (it is still invoked implicitly, because `client` depends on `session`). And if you want to write a test that talks directly to the db, you can do that as well:
```python
def test_something(session):
    session.query(...)
```
Or both, for example if you want to prepare the db state before an API call:
```python
def test_something_else(client, session):
    session.add(...)
    session.commit()
    client.get(...)
```
Both the application code and test code will see the same state of the db.
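For completeness, a slightly more concrete version of that last pattern. It assumes a SQLAlchemy model called `Todo` with a `text` column living in `models.py`; adjust the import and the fields to whatever your project actually uses:

```python
# Hypothetical sketch: the Todo model name, its location and its columns
# are assumptions about your project, not taken from the code above.
from models import Todo


def test_get_todos_prepared_via_session(client, session):
    # Prepare state directly in the db; session.commit() only ends the
    # SAVEPOINT, so nothing is persisted beyond this test.
    session.add(Todo(text="prepared directly in the db"))
    session.commit()

    # The application code sees the same in-progress transaction.
    response = client.get("/todos/")
    assert response.status_code == 200
    assert len(response.json()) == 1
```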