Files
Scraperr/api/backend/tests/conftest.py
Jayden Pyles 875a3684c9 Feat/swap to sqlalchemy (#99)
* chore: wip swap to sqlalchemy

* feat: swap to sqlalchemy

* feat: swap to sqlalchemy

* feat: swap to sqlalchemy

* feat: swap to sqlalchemy
2025-07-12 21:12:33 -05:00

109 lines
3.0 KiB
Python

# STL
import os
import asyncio
from typing import Any, Generator, AsyncGenerator
# PDM
import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from proxy import Proxy
from sqlalchemy import text
from sqlalchemy.pool import NullPool
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
# LOCAL
from api.backend.app import app
from api.backend.database.base import get_db
from api.backend.database.models import Base
from api.backend.tests.constants import TEST_DB_PATH
@pytest.fixture(scope="session", autouse=True)
def running_proxy():
    """Run a local proxy.py server for the whole test session.

    The server is configured to listen on 127.0.0.1:8080. Because the fixture
    is session-scoped and autouse, it is set up before the first test and
    shut down after the session's tests finish.

    Yields:
        The running ``Proxy`` instance.
    """
    cli_args = ["--hostname", "127.0.0.1", "--port", "8080"]
    server = Proxy(cli_args)
    server.setup()
    yield server
    server.shutdown()
@pytest.fixture(scope="session")
def test_db_path() -> str:
    """Return the path of the SQLite file used as the test database."""
    db_path: str = TEST_DB_PATH
    return db_path
@pytest.fixture(scope="session", autouse=True)
def test_db(test_db_path: str) -> Generator[str, None, None]:
    """Create a fresh SQLite test database once per test session.

    Any database file left over from a previous run is deleted, the schema
    defined by ``Base.metadata`` is created, and the file is removed again
    when the session ends.

    Args:
        test_db_path: Filesystem path of the SQLite test database.

    Yields:
        The path of the prepared database file.
    """
    db_dir = os.path.dirname(test_db_path)
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the path actually contains one.
    if db_dir:
        os.makedirs(db_dir, exist_ok=True)

    if os.path.exists(test_db_path):
        os.remove(test_db_path)

    test_db_url = f"sqlite+aiosqlite:///{test_db_path}"
    # NullPool: no pooled connection should outlive the short-lived event
    # loop that asyncio.run() creates below.
    engine = create_async_engine(test_db_url, echo=False, poolclass=NullPool)

    async def setup_db() -> None:
        try:
            async with engine.begin() as conn:
                await conn.run_sync(Base.metadata.create_all)
        finally:
            # Release the engine on the same loop that opened its connections;
            # otherwise cleanup happens later against a closed loop.
            await engine.dispose()

    asyncio.run(setup_db())

    yield test_db_path

    if os.path.exists(test_db_path):
        os.remove(test_db_path)
@pytest_asyncio.fixture(scope="session")
async def test_engine(test_db: str) -> AsyncGenerator[Any, None]:
    """Yield a session-wide async engine bound to the prepared test database.

    The engine is built from the path yielded by the ``test_db`` fixture
    rather than a hard-coded constant. It therefore always targets the same
    file that fixture created, including when ``test_db_path`` is overridden.
    Requesting ``test_db`` also makes the setup ordering explicit.

    Args:
        test_db: Path of the SQLite database file prepared by ``test_db``.

    Yields:
        An async SQLAlchemy engine using ``NullPool``.
    """
    engine = create_async_engine(
        f"sqlite+aiosqlite:///{test_db}", poolclass=NullPool
    )
    try:
        async with engine.begin() as conn:
            # create_all defaults to checkfirst=True, so tables that already
            # exist are left untouched.
            await conn.run_sync(Base.metadata.create_all)
        yield engine
    finally:
        await engine.dispose()
@pytest_asyncio.fixture(scope="function")
async def db_session(test_engine: Any) -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` for one test, then empty every table.

    Args:
        test_engine: Async engine the session is bound to.

    Yields:
        A session created with ``expire_on_commit=False``.
    """
    async_session = async_sessionmaker(
        bind=test_engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )
    async with async_session() as session:
        try:
            yield session
        finally:
            # A test that hit a failed flush leaves the session requiring a
            # rollback; without it the DELETEs below raise
            # PendingRollbackError and rows leak into later tests.
            await session.rollback()
            # Reverse dependency order deletes child tables before parents.
            # Table.delete() lets SQLAlchemy quote the table name instead of
            # interpolating it into raw SQL.
            for table in reversed(Base.metadata.sorted_tables):
                await session.execute(table.delete())
            await session.commit()
@pytest.fixture()
def override_get_db(db_session: AsyncSession):
    """Build a replacement for the ``get_db`` dependency.

    Args:
        db_session: The per-test session to hand out.

    Returns:
        An async generator function that yields ``db_session`` unchanged,
        suitable for ``app.dependency_overrides``.
    """

    async def _yield_test_session() -> AsyncGenerator[AsyncSession, None]:
        yield db_session

    return _yield_test_session
@pytest_asyncio.fixture()
async def client(override_get_db: Any) -> AsyncGenerator[AsyncClient, None]:
    """Yield an HTTP client that talks to ``app`` in-process via ASGI.

    While the client is in use, ``get_db`` is overridden so that requests
    receive the test's session.

    Args:
        override_get_db: Replacement dependency for ``get_db``.

    Yields:
        An ``AsyncClient`` with base URL ``http://test``.
    """
    app.dependency_overrides[get_db] = override_get_db
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as c:
            yield c
    finally:
        # Remove the overrides from the global app even if the client fails
        # to open or close; otherwise later tests would still resolve get_db
        # to this test's session.
        app.dependency_overrides.clear()