Feat/swap to sqlalchemy (#99)

* chore: wip swap to sqlalchemy

* feat: swap to sqlalchemy

* feat: swap to sqlalchemy

* feat: swap to sqlalchemy

* feat: swap to sqlalchemy
This commit is contained in:
Jayden Pyles
2025-07-12 21:12:33 -05:00
committed by GitHub
parent b096fb1b3c
commit 875a3684c9
35 changed files with 1593 additions and 1363 deletions
+147
View File
@@ -0,0 +1,147 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts.
# this is typically a path given in POSIX (e.g. forward slashes)
# format, relative to the token %(here)s which refers to the location of this
# ini file
script_location = %(here)s/alembic
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory. for multiple paths, the path separator
# is defined by "path_separator" below.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires python>=3.9 or the backports.zoneinfo library, plus the tzdata library.
# Any required deps can be installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to <script_location>/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "path_separator"
# below.
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
# path_separator; This indicates what character is used to split lists of file
# paths, including version_locations and prepend_sys_path within configparser
# files such as alembic.ini.
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
# to provide os-dependent path splitting.
#
# Note that in order to support legacy alembic.ini files, this default does NOT
# take place if path_separator is not present in alembic.ini. If this
# option is omitted entirely, fallback logic is as follows:
#
# 1. Parsing of the version_locations option falls back to using the legacy
# "version_path_separator" key, which if absent then falls back to the legacy
# behavior of splitting on spaces and/or commas.
# 2. Parsing of the prepend_sys_path option falls back to the legacy
# behavior of splitting on spaces, commas, or colons.
#
# Valid values for path_separator are:
#
# path_separator = :
# path_separator = ;
# path_separator = space
# path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
path_separator = os
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
# database URL. This is consumed by the user-maintained env.py script only.
# other means of configuring database URLs may be customized within the env.py
# file.
sqlalchemy.url = driver://user:pass@localhost/dbname
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
# hooks = ruff
# ruff.type = module
# ruff.module = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Alternatively, use the exec runner to execute a binary found on your PATH
# hooks = ruff
# ruff.type = exec
# ruff.executable = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Logging configuration. This is also consumed by the user-maintained
# env.py script only.
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARNING
handlers = console
qualname =
[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
+1
View File
@@ -0,0 +1 @@
Generic single-database configuration.
+103
View File
@@ -0,0 +1,103 @@
# STL
import os
import sys
from logging.config import fileConfig
# PDM
from dotenv import load_dotenv
from sqlalchemy import pool, engine_from_config
# LOCAL
from alembic import context
from api.backend.database.base import Base
from api.backend.database.models import Job, User, CronJob # type: ignore
# Populate os.environ from a local .env file before DATABASE_URL is read.
load_dotenv()

# NOTE(review): this runs after the `api.backend...` imports at the top of the
# file, so it cannot affect them — confirm whether it is still required.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "api")))

# Load the raw (possibly async) database URL.
raw_database_url = os.getenv("DATABASE_URL", "sqlite+aiosqlite:///data/database.db")

# Map async dialects to sync ones.
driver_downgrade_map = {
    "sqlite+aiosqlite": "sqlite",
    "postgresql+asyncpg": "postgresql",
    "mysql+aiomysql": "mysql",
}

# Swap the scheme of the first matching async driver; any other URL is
# assumed to be sync already and is used unchanged.
for async_driver, sync_driver in driver_downgrade_map.items():
    if raw_database_url.startswith(async_driver + "://"):
        sync_database_url = raw_database_url.replace(async_driver, sync_driver, 1)
        break
else:
    sync_database_url = raw_database_url

# Apply it to the Alembic config. The value goes through ConfigParser
# interpolation, so a literal "%" (e.g. a URL-encoded password) must be
# written as "%%" or set_main_option raises; reads return the single "%".
config = context.config
config.set_main_option("sqlalchemy.url", sync_database_url.replace("%", "%%"))

# Configure Python logging from the ini file, when one is in use.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# MetaData object used for 'autogenerate' support.
target_metadata = Base.metadata
def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    The context is configured from the ``sqlalchemy.url`` option only, so no
    Engine is created and no DBAPI has to be available. Calls to
    context.execute() emit the given string to the script output.
    """
    offline_options = {
        "url": config.get_main_option("sqlalchemy.url"),
        "target_metadata": target_metadata,
        "literal_binds": True,
        "dialect_opts": {"paramstyle": "named"},
    }
    context.configure(**offline_options)

    with context.begin_transaction():
        context.run_migrations()
def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    Builds an Engine from the ``sqlalchemy.``-prefixed options of the main ini
    section (without connection pooling) and runs the migrations on a
    connection obtained from it.
    """
    ini_section = config.get_section(config.config_ini_section, {})
    connectable = engine_from_config(
        ini_section,
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with connectable.connect() as connection:
        context.configure(connection=connection, target_metadata=target_metadata)

        with context.begin_transaction():
            context.run_migrations()
# Dispatch on the mode reported by the Alembic context.
if not context.is_offline_mode():
    run_migrations_online()
else:
    run_migrations_offline()
+28
View File
@@ -0,0 +1,28 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
"""Upgrade schema."""
${upgrades if upgrades else "pass"}
def downgrade() -> None:
"""Downgrade schema."""
${downgrades if downgrades else "pass"}
@@ -0,0 +1,67 @@
"""initial revision
Revision ID: 6aa921d2e637
Revises:
Create Date: 2025-07-12 20:17:44.448034
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '6aa921d2e637'
down_revision: Union[str, Sequence[str], None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Upgrade schema: create the ``users``, ``jobs`` and ``cron_jobs`` tables."""

    def _timestamp(name: str) -> sa.Column:
        # Non-null, timezone-aware timestamp defaulting to the server's clock.
        return sa.Column(
            name,
            sa.DateTime(timezone=True),
            server_default=sa.text('(CURRENT_TIMESTAMP)'),
            nullable=False,
        )

    # ### commands auto generated by Alembic - please adjust! ###
    # Created first: both other tables reference users.email.
    op.create_table(
        'users',
        sa.Column('email', sa.String(length=255), nullable=False),
        sa.Column('hashed_password', sa.String(length=255), nullable=False),
        sa.Column('full_name', sa.String(length=255), nullable=True),
        sa.Column('disabled', sa.Boolean(), nullable=True),
        sa.PrimaryKeyConstraint('email'),
    )
    op.create_table(
        'jobs',
        sa.Column('id', sa.String(length=64), nullable=False),
        sa.Column('url', sa.String(length=2048), nullable=False),
        sa.Column('elements', sa.JSON(), nullable=False),
        sa.Column('user', sa.String(length=255), nullable=True),
        _timestamp('time_created'),
        sa.Column('result', sa.JSON(), nullable=False),
        sa.Column('status', sa.String(length=50), nullable=False),
        sa.Column('chat', sa.JSON(), nullable=True),
        sa.Column('job_options', sa.JSON(), nullable=True),
        sa.Column('agent_mode', sa.Boolean(), nullable=False),
        sa.Column('prompt', sa.String(length=1024), nullable=True),
        sa.Column('favorite', sa.Boolean(), nullable=False),
        sa.ForeignKeyConstraint(['user'], ['users.email']),
        sa.PrimaryKeyConstraint('id'),
    )
    # Created last: references both jobs.id and users.email.
    op.create_table(
        'cron_jobs',
        sa.Column('id', sa.String(length=64), nullable=False),
        sa.Column('user_email', sa.String(length=255), nullable=False),
        sa.Column('job_id', sa.String(length=64), nullable=False),
        sa.Column('cron_expression', sa.String(length=255), nullable=False),
        _timestamp('time_created'),
        _timestamp('time_updated'),
        sa.ForeignKeyConstraint(['job_id'], ['jobs.id']),
        sa.ForeignKeyConstraint(['user_email'], ['users.email']),
        sa.PrimaryKeyConstraint('id'),
    )
    # ### end Alembic commands ###
def downgrade() -> None:
    """Downgrade schema: drop the three tables created by ``upgrade``."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Referencing tables go first so the foreign keys never dangle.
    for table_name in ('cron_jobs', 'jobs', 'users'):
        op.drop_table(table_name)
    # ### end Alembic commands ###
+1 -4
View File
@@ -15,7 +15,6 @@ from api.backend.scheduler import scheduler
from api.backend.ai.ai_router import ai_router
from api.backend.job.job_router import job_router
from api.backend.auth.auth_router import auth_router
from api.backend.database.startup import init_database
from api.backend.stats.stats_router import stats_router
from api.backend.job.cron_scheduling.cron_scheduling import start_cron_scheduler
@@ -36,10 +35,8 @@ async def lifespan(_: FastAPI):
# Startup
LOG.info("Starting application...")
init_database()
LOG.info("Starting cron scheduler...")
start_cron_scheduler(scheduler)
await start_cron_scheduler(scheduler)
scheduler.start()
LOG.info("Cron scheduler started successfully")
+14 -5
View File
@@ -6,9 +6,11 @@ from datetime import timedelta
# PDM
from fastapi import Depends, APIRouter, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
# LOCAL
from api.backend.auth.schemas import User, Token, UserCreate
from api.backend.database.base import AsyncSessionLocal, get_db
from api.backend.auth.auth_utils import (
ACCESS_TOKEN_EXPIRE_MINUTES,
get_current_user,
@@ -16,7 +18,7 @@ from api.backend.auth.auth_utils import (
get_password_hash,
create_access_token,
)
from api.backend.database.common import update
from api.backend.database.models import User as DatabaseUser
from api.backend.routers.handle_exceptions import handle_exceptions
auth_router = APIRouter()
@@ -26,8 +28,8 @@ LOG = logging.getLogger("Auth")
@auth_router.post("/auth/token", response_model=Token)
@handle_exceptions(logger=LOG)
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
user = await authenticate_user(form_data.username, form_data.password)
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: AsyncSession = Depends(get_db)):
user = await authenticate_user(form_data.username, form_data.password, db)
if not user:
raise HTTPException(
@@ -56,8 +58,15 @@ async def create_user(user: UserCreate):
user_dict["hashed_password"] = hashed_password
del user_dict["password"]
query = "INSERT INTO users (email, hashed_password, full_name) VALUES (?, ?, ?)"
_ = update(query, (user_dict["email"], hashed_password, user_dict["full_name"]))
async with AsyncSessionLocal() as session:
new_user = DatabaseUser(
email=user.email,
hashed_password=user_dict["hashed_password"],
full_name=user.full_name,
)
session.add(new_user)
await session.commit()
return user_dict
+24 -13
View File
@@ -8,12 +8,15 @@ from datetime import datetime, timedelta
from jose import JWTError, jwt
from dotenv import load_dotenv
from fastapi import Depends, HTTPException, status
from sqlalchemy import select
from passlib.context import CryptContext
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession
# LOCAL
from api.backend.auth.schemas import User, UserInDB, TokenData
from api.backend.database.common import query
from api.backend.database.base import get_db
from api.backend.database.models import User as UserModel
LOG = logging.getLogger("Auth")
@@ -37,18 +40,24 @@ def get_password_hash(password: str):
return pwd_context.hash(password)
async def get_user(session: AsyncSession, email: str) -> UserInDB | None:
    """Fetch the user with the given e-mail and return it as a ``UserInDB``.

    Args:
        session: Async SQLAlchemy session used to run the lookup.
        email: Primary-key e-mail address to search for.

    Returns:
        The matching user converted to ``UserInDB``, or ``None`` when no row
        matches.
    """
    stmt = select(UserModel).where(UserModel.email == email)
    result = await session.execute(stmt)
    user = result.scalars().first()

    if not user:
        return None

    return UserInDB(
        email=str(user.email),
        hashed_password=str(user.hashed_password),
        # The full_name column is nullable; str(None) would produce the
        # literal text "None", so a missing name is passed through as None.
        # NOTE(review): assumes UserInDB.full_name accepts None — confirm.
        full_name=None if user.full_name is None else str(user.full_name),
        disabled=bool(user.disabled),
    )
async def authenticate_user(email: str, password: str):
user = await get_user(email)
async def authenticate_user(email: str, password: str, db: AsyncSession):
user = await get_user(db, email)
if not user:
return False
@@ -74,7 +83,9 @@ def create_access_token(
return encoded_jwt
async def get_current_user(token: str = Depends(oauth2_scheme)):
async def get_current_user(
db: AsyncSession = Depends(get_db), token: str = Depends(oauth2_scheme)
):
LOG.debug(f"Getting current user with token: {token}")
if not token:
@@ -82,7 +93,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme)):
return EMPTY_USER
if len(token.split(".")) != 3:
LOG.error(f"Malformed token: {token}")
LOG.debug(f"Malformed token: {token}")
return EMPTY_USER
try:
@@ -117,7 +128,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme)):
LOG.error(f"Exception occurred: {e}")
return EMPTY_USER
user = await get_user(email=token_data.email or "")
user = await get_user(db, email=token_data.email or "")
if user is None:
return EMPTY_USER
@@ -125,7 +136,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme)):
return user
async def require_user(token: str = Depends(oauth2_scheme)):
async def require_user(db: AsyncSession, token: str = Depends(oauth2_scheme)):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
@@ -150,7 +161,7 @@ async def require_user(token: str = Depends(oauth2_scheme)):
except JWTError:
raise credentials_exception
user = await get_user(email=token_data.email or "")
user = await get_user(db, email=token_data.email or "")
if user is None:
raise credentials_exception
+1 -1
View File
@@ -2,7 +2,7 @@
import os
from pathlib import Path
DATABASE_PATH = "data/database.db"
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite+aiosqlite:///data/database.db")
RECORDINGS_DIR = Path("media/recordings")
RECORDINGS_ENABLED = os.getenv("RECORDINGS_ENABLED", "true").lower() == "true"
MEDIA_DIR = Path("media")
-5
View File
@@ -1,5 +0,0 @@
# LOCAL
from .common import insert, update, connect
from .schema import INIT_QUERY
__all__ = ["insert", "update", "INIT_QUERY", "connect"]
+26
View File
@@ -0,0 +1,26 @@
# STL
from typing import AsyncGenerator
# PDM
from sqlalchemy.orm import declarative_base
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
# LOCAL
from api.backend.constants import DATABASE_URL
# Module-wide async engine for the configured DATABASE_URL (SQL echo off).
engine = create_async_engine(DATABASE_URL, echo=False, future=True)

# Factory for AsyncSession objects bound to the engine. Sessions neither
# autoflush nor expire loaded objects on commit.
AsyncSessionLocal = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
    autoflush=False,
    autocommit=False,
)

# Declarative base class that the ORM models inherit from.
Base = declarative_base()
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a single ``AsyncSession`` created from ``AsyncSessionLocal``.

    The session's async context manager is exited once the consumer is done
    with the generator.
    """
    async with AsyncSessionLocal() as db_session:
        yield db_session
-90
View File
@@ -1,90 +0,0 @@
# STL
import logging
import sqlite3
from typing import Any, Optional
# LOCAL
from api.backend.constants import DATABASE_PATH
from api.backend.database.utils import format_json, format_sql_row_to_python
LOG = logging.getLogger("Database")
def connect():
    """Open the SQLite database at ``DATABASE_PATH`` and return a cursor.

    Every SQL statement run on the connection is echoed through ``print``.
    """
    db = sqlite3.connect(DATABASE_PATH)
    db.set_trace_callback(print)
    return db.cursor()
def insert(query: str, values: tuple[Any, ...]):
    """Execute ``query`` with ``values`` and commit.

    A list copy of ``values`` is passed through ``format_json`` before it is
    bound. A ``sqlite3.Error`` is logged and re-raised; the cursor and
    connection are closed in every case.
    """
    connection = sqlite3.connect(DATABASE_PATH)
    cursor = connection.cursor()

    # Work on a copy so the caller's tuple is never touched.
    params = list(values)
    format_json(params)

    try:
        cursor.execute(query, params)
        connection.commit()
    except sqlite3.Error as e:
        LOG.error(f"An error occurred: {e}")
        raise e
    finally:
        cursor.close()
        connection.close()
def query(query: str, values: Optional[tuple[Any, ...]] = None):
    """Run ``query`` and return every row as a formatted dict.

    Args:
        query: SQL text to execute.
        values: Optional bind parameters; omitted from the call when falsy.

    Returns:
        A list with one ``format_sql_row_to_python``-converted dict per row.
    """
    connection = sqlite3.connect(DATABASE_PATH)
    connection.row_factory = sqlite3.Row  # rows become mapping-like
    cursor = connection.cursor()

    execute_args = (query, values) if values else (query,)
    try:
        cursor.execute(*execute_args)
        fetched = cursor.fetchall()
    finally:
        cursor.close()
        connection.close()

    return [format_sql_row_to_python(dict(record)) for record in fetched]
def update(query: str, values: Optional[tuple[Any, ...]] = None):
    """Execute a modifying statement and return the affected row count.

    A list copy of ``values`` (when given) is passed through ``format_json``
    before it is bound. A ``sqlite3.Error`` is logged rather than raised, in
    which case ``0`` is returned.
    """
    connection = sqlite3.connect(DATABASE_PATH)
    cursor = connection.cursor()

    params = list(values) if values else None
    if params is not None:
        format_json(params)

    try:
        res = cursor.execute(query, params) if params else cursor.execute(query)
        connection.commit()
        return res.rowcount
    except sqlite3.Error as e:
        LOG.error(f"An error occurred: {e}")
    finally:
        cursor.close()
        connection.close()

    # Only reached when the statement failed.
    return 0
+65
View File
@@ -0,0 +1,65 @@
# PDM
from sqlalchemy import JSON, Column, String, Boolean, DateTime, ForeignKey, func
from sqlalchemy.orm import relationship
# LOCAL
from api.backend.database.base import Base
class User(Base):
    """ORM model for the ``users`` table, keyed by e-mail address."""

    __tablename__ = "users"

    # Identity and credentials.
    email = Column(String(255), primary_key=True, nullable=False)
    hashed_password = Column(String(255), nullable=False)

    # Optional profile data and account state.
    full_name = Column(String(255), nullable=True)
    disabled = Column(Boolean, default=False)

    # Rows owned by this user; both collections cascade "all, delete-orphan".
    jobs = relationship(
        "Job", back_populates="user_obj", cascade="all, delete-orphan"
    )
    cron_jobs = relationship(
        "CronJob", back_populates="user_obj", cascade="all, delete-orphan"
    )
class Job(Base):
    """ORM model for the ``jobs`` table."""

    __tablename__ = "jobs"

    id = Column(String(64), primary_key=True, nullable=False)
    url = Column(String(2048), nullable=False)
    elements = Column(JSON, nullable=False)
    # Optional owner, stored as the e-mail of a row in ``users``.
    user = Column(String(255), ForeignKey("users.email"), nullable=True)
    # Filled in by the database when the row is inserted without a value.
    time_created = Column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    result = Column(JSON, nullable=False)
    status = Column(String(50), nullable=False)
    chat = Column(JSON, nullable=True)
    job_options = Column(JSON, nullable=True)
    agent_mode = Column(Boolean, default=False, nullable=False)
    prompt = Column(String(1024), nullable=True)
    favorite = Column(Boolean, default=False, nullable=False)

    # Owning user, plus the cron schedules that point at this job.
    user_obj = relationship("User", back_populates="jobs")
    cron_jobs = relationship(
        "CronJob", back_populates="job_obj", cascade="all, delete-orphan"
    )
class CronJob(Base):
    """ORM model for the ``cron_jobs`` table: a cron schedule tied to a job and a user."""

    __tablename__ = "cron_jobs"

    id = Column(String(64), primary_key=True, nullable=False)
    user_email = Column(String(255), ForeignKey("users.email"), nullable=False)
    job_id = Column(String(64), ForeignKey("jobs.id"), nullable=False)
    cron_expression = Column(String(255), nullable=False)

    # Both timestamps default to the database clock on insert.
    time_created = Column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    # Additionally refreshed whenever the row is updated through the ORM/Core.
    time_updated = Column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
    )

    user_obj = relationship("User", back_populates="cron_jobs")
    job_obj = relationship("Job", back_populates="cron_jobs")
-4
View File
@@ -1,4 +0,0 @@
# LOCAL
from .job.job_queries import DELETE_JOB_QUERY, JOB_INSERT_QUERY
__all__ = ["JOB_INSERT_QUERY", "DELETE_JOB_QUERY"]
+47 -48
View File
@@ -2,62 +2,61 @@
import logging
from typing import Any
# PDM
from sqlalchemy import delete as sql_delete
from sqlalchemy import select
from sqlalchemy import update as sql_update
# LOCAL
from api.backend.database.utils import format_list_for_query
from api.backend.database.common import query, insert, update
JOB_INSERT_QUERY = """
INSERT INTO jobs
(id, url, elements, user, time_created, result, status, chat, job_options, agent_mode, prompt)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"""
DELETE_JOB_QUERY = """
DELETE FROM jobs WHERE id IN ()
"""
from api.backend.database.base import AsyncSessionLocal
from api.backend.database.models import Job
LOG = logging.getLogger("Database")
async def insert_job(item: dict[str, Any]) -> None:
    """Insert one row into ``jobs`` built from ``item`` and commit it.

    Args:
        item: Mapping that must contain every key listed in ``job_fields``;
            a missing key raises ``KeyError``.
    """
    job_fields = (
        "id",
        "url",
        "elements",
        "user",
        "time_created",
        "result",
        "status",
        "chat",
        "job_options",
        "agent_mode",
        "prompt",
    )

    async with AsyncSessionLocal() as session:
        new_job = Job(**{field: item[field] for field in job_fields})
        session.add(new_job)
        await session.commit()
        LOG.info(f"Inserted item: {item}")
async def get_queued_job():
    """Return the most recently created job whose status is "Queued", or ``None``."""
    newest_queued = (
        select(Job)
        .where(Job.status == "Queued")
        .order_by(Job.time_created.desc())
        .limit(1)
    )

    async with AsyncSessionLocal() as session:
        job = (await session.execute(newest_queued)).scalars().first()
        LOG.info(f"Got queued job: {job}")
        return job
async def update_job(ids: list[str], updates: dict[str, Any]):
    """Apply ``updates`` (column name -> value) to every job whose id is in ``ids``.

    Does nothing when ``updates`` is empty.
    """
    if not updates:
        return

    bulk_update = sql_update(Job).where(Job.id.in_(ids)).values(**updates)

    async with AsyncSessionLocal() as session:
        outcome = await session.execute(bulk_update)
        await session.commit()
        LOG.debug(f"Updated job count: {outcome.rowcount}")
async def delete_jobs(jobs: list[str]):
@@ -65,9 +64,9 @@ async def delete_jobs(jobs: list[str]):
LOG.info("No jobs to delete.")
return False
query = f"DELETE FROM jobs WHERE id IN {format_list_for_query(jobs)}"
res = update(query, tuple(jobs))
LOG.info(f"Deleted jobs: {res}")
return res
async with AsyncSessionLocal() as session:
stmt = sql_delete(Job).where(Job.id.in_(jobs))
result = await session.execute(stmt)
await session.commit()
LOG.info(f"Deleted jobs count: {result.rowcount}")
return result.rowcount
@@ -1,41 +1,43 @@
# PDM
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
# LOCAL
from api.backend.database.common import query
from api.backend.database.models import Job
async def average_elements_per_link(session: AsyncSession, user_email: str):
    """Per-day statistics over one user's completed jobs.

    Returns:
        A list of dicts with keys ``date``, ``average_elements`` (average
        ``json_array_length`` of ``elements``) and ``count``, ordered by date.
    """
    day = func.date(Job.time_created)
    avg_elements = func.avg(func.json_array_length(Job.elements))

    per_day = (
        select(
            day.label("date"),
            avg_elements.label("average_elements"),
            func.count().label("count"),
        )
        # Successive where() calls are AND-ed together.
        .where(Job.status == "Completed")
        .where(Job.user == user_email)
        .group_by(day)
        .order_by("date")
    )

    fetched = (await session.execute(per_day)).all()
    return [dict(record._mapping) for record in fetched]
async def get_jobs_per_day(session: AsyncSession, user_email: str):
    """Count one user's completed jobs per calendar day.

    Returns:
        A list of dicts with keys ``date`` and ``job_count``, ordered by date.
    """
    day = func.date(Job.time_created)

    per_day = (
        select(day.label("date"), func.count().label("job_count"))
        # Successive where() calls are AND-ed together.
        .where(Job.status == "Completed")
        .where(Job.user == user_email)
        .group_by(day)
        .order_by("date")
    )

    fetched = (await session.execute(per_day)).all()
    return [dict(record._mapping) for record in fetched]
-3
View File
@@ -1,3 +0,0 @@
from .schema import INIT_QUERY
__all__ = ["INIT_QUERY"]
-34
View File
@@ -1,34 +0,0 @@
INIT_QUERY = """
CREATE TABLE IF NOT EXISTS jobs (
id STRING PRIMARY KEY NOT NULL,
url STRING NOT NULL,
elements JSON NOT NULL,
user STRING,
time_created DATETIME NOT NULL,
result JSON NOT NULL,
status STRING NOT NULL,
chat JSON,
job_options JSON
);
CREATE TABLE IF NOT EXISTS users (
email STRING PRIMARY KEY NOT NULL,
hashed_password STRING NOT NULL,
full_name STRING,
disabled BOOLEAN
);
CREATE TABLE IF NOT EXISTS cron_jobs (
id STRING PRIMARY KEY NOT NULL,
user_email STRING NOT NULL,
job_id STRING NOT NULL,
cron_expression STRING NOT NULL,
time_created DATETIME NOT NULL,
time_updated DATETIME NOT NULL,
FOREIGN KEY (job_id) REFERENCES jobs(id)
);
ALTER TABLE jobs ADD COLUMN agent_mode BOOLEAN NOT NULL DEFAULT FALSE;
ALTER TABLE jobs ADD COLUMN prompt STRING;
ALTER TABLE jobs ADD COLUMN favorite BOOLEAN NOT NULL DEFAULT FALSE;
"""
+33 -46
View File
@@ -1,6 +1,8 @@
# STL
import logging
import sqlite3
# PDM
from sqlalchemy.exc import IntegrityError
# LOCAL
from api.backend.constants import (
@@ -9,61 +11,46 @@ from api.backend.constants import (
DEFAULT_USER_PASSWORD,
DEFAULT_USER_FULL_NAME,
)
from api.backend.database.base import Base, AsyncSessionLocal, engine
from api.backend.auth.auth_utils import get_password_hash
from api.backend.database.common import insert, connect
from api.backend.database.schema import INIT_QUERY
from api.backend.database.models import User
LOG = logging.getLogger("Database")
async def init_database():
    """Create the schema and, when registration is disabled, a default user.

    All tables known to ``Base.metadata`` are created first. If
    ``REGISTRATION_ENABLED`` is false, the ``DEFAULT_USER_*`` constants must
    all be set; the default user is then inserted unless a row with that
    e-mail already exists.

    Raises:
        SystemExit: with status 1 when registration is disabled and any
            ``DEFAULT_USER_*`` value is missing.
    """
    LOG.info("Creating database schema...")

    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)

    if REGISTRATION_ENABLED:
        return

    default_user_email = DEFAULT_USER_EMAIL
    default_user_password = DEFAULT_USER_PASSWORD
    default_user_full_name = DEFAULT_USER_FULL_NAME

    if not (default_user_email and default_user_password and default_user_full_name):
        LOG.error("DEFAULT_USER_* env vars are not set!")
        # Raise directly: the bare `exit` helper is injected by the `site`
        # module and is not guaranteed to exist in every runtime.
        raise SystemExit(1)

    async with AsyncSessionLocal() as session:
        user = await session.get(User, default_user_email)
        if user:
            LOG.info("Default user already exists. Skipping creation.")
            return

        LOG.info("Creating default user...")
        new_user = User(
            email=default_user_email,
            hashed_password=get_password_hash(default_user_password),
            full_name=default_user_full_name,
            disabled=False,
        )

        try:
            session.add(new_user)
            await session.commit()
            LOG.info(f"Created default user: {default_user_email}")
        except IntegrityError as e:
            # The existence check and the insert are not atomic, so a
            # concurrent insert can still surface here.
            await session.rollback()
            LOG.warning(f"Could not create default user (already exists?): {e}")
+7
View File
@@ -1,6 +1,7 @@
# STL
import json
from typing import Any
from datetime import datetime
def format_list_for_query(ids: list[str]):
@@ -28,3 +29,9 @@ def format_json(items: list[Any]):
if isinstance(item, (dict, list)):
formatted_item = json.dumps(item)
items[idx] = formatted_item
def parse_datetime(dt_str: str) -> datetime:
    """Parse an ISO-8601 timestamp, accepting a trailing ``Z``/``z`` for UTC.

    Args:
        dt_str: ISO-8601 date-time text, optionally ending in the UTC
            designator ``Z`` (either case).

    Returns:
        The parsed ``datetime``; timezone-aware when the input carries an
        offset or a UTC designator, naive otherwise.

    Raises:
        ValueError: if the text is not a valid ISO-8601 date-time.
    """
    # Swap only the final character for an explicit offset, so a "Z"
    # elsewhere in the text is never rewritten.
    if dt_str.endswith(("Z", "z")):
        dt_str = f"{dt_str[:-1]}+00:00"

    return datetime.fromisoformat(dt_str)
@@ -2,83 +2,74 @@
import uuid
import logging
import datetime
from typing import Any
from typing import Any, List
# PDM
from sqlalchemy import select
from apscheduler.triggers.cron import CronTrigger
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.schedulers.asyncio import AsyncIOScheduler
# LOCAL
from api.backend.job import insert as insert_job
from api.backend.schemas.cron import CronJob
from api.backend.database.common import query, insert
from api.backend.database.base import AsyncSessionLocal
from api.backend.database.models import Job, CronJob
LOG = logging.getLogger("Cron")
async def insert_cron_job(cron_job: CronJob) -> bool:
    """Persist ``cron_job`` in its own session and commit.

    Returns:
        ``True`` once the commit has completed; failures propagate as
        exceptions.
    """
    async with AsyncSessionLocal() as db_session:
        db_session.add(cron_job)
        await db_session.commit()

    return True
async def delete_cron_job(id: str, user_email: str) -> bool:
    """Delete the cron job with the given id if it belongs to ``user_email``.

    Returns:
        ``True``. NOTE(review): the flattened source does not show whether
        ``return True`` sat inside the ``if``; the unconditional form implied
        by the ``-> bool`` annotation is kept here — confirm.
    """
    owned_row = (
        select(CronJob)
        .where(CronJob.id == id)
        .where(CronJob.user_email == user_email)
    )

    async with AsyncSessionLocal() as session:
        cron_job = (await session.execute(owned_row)).scalars().first()

        if cron_job is not None:
            await session.delete(cron_job)
            await session.commit()

    return True
async def get_cron_jobs(user_email: str) -> List[CronJob]:
    """Return every cron job whose ``user_email`` equals the given address."""
    owned_by_user = select(CronJob).where(CronJob.user_email == user_email)

    async with AsyncSessionLocal() as session:
        found = (await session.execute(owned_by_user)).scalars().all()
        return list(found)
async def get_all_cron_jobs() -> List[CronJob]:
    """Return all stored cron jobs, with no filtering."""
    async with AsyncSessionLocal() as db:
        rows = await db.execute(select(CronJob))
        return [*rows.scalars().all()]
async def insert_job_from_cron_job(job: dict[str, Any]):
    """Insert a fresh, queued copy of ``job``.

    The fields of ``job`` are carried over, while ``id``, ``status``,
    ``result``, ``chat`` and both timestamps are overwritten: a new hex
    UUID, ``"Queued"``, an empty result, no chat and the current UTC time.

    Args:
        job: Field values of the job to clone.
    """
    async with AsyncSessionLocal() as db:
        queued_copy = {
            **job,
            "id": uuid.uuid4().hex,
            "status": "Queued",
            "result": "",
            "chat": None,
            "time_created": datetime.datetime.now(datetime.timezone.utc),
            "time_updated": datetime.datetime.now(datetime.timezone.utc),
        }
        await insert_job(queued_copy, db)
def get_cron_job_trigger(cron_expression: str):
expression_parts = cron_expression.split()
if len(expression_parts) != 5:
print(f"Invalid cron expression: {cron_expression}")
LOG.warning(f"Invalid cron expression: {cron_expression}")
return None
minute, hour, day, month, day_of_week = expression_parts
@@ -88,19 +79,37 @@ def get_cron_job_trigger(cron_expression: str):
)
async def start_cron_scheduler(scheduler: AsyncIOScheduler):
    """Register every stored cron job with ``scheduler``.

    For each ``CronJob`` row the referenced ``Job`` is loaded, flattened into
    a plain column-name -> value dict, and scheduled so that
    ``insert_job_from_cron_job`` runs with that dict on the cron trigger.

    Rows are skipped (not scheduled) when:
      * ``get_cron_job_trigger`` returns no trigger for the expression, or
      * the referenced job no longer exists. Scheduling it with an empty
        payload would only fail later, on every firing, when the insert
        reads fields such as ``url`` that are not there.

    Args:
        scheduler: Scheduler that receives one job per valid cron row.
    """
    async with AsyncSessionLocal() as session:
        result = await session.execute(select(CronJob))
        cron_jobs = result.scalars().all()

        LOG.info(f"Cron jobs: {cron_jobs}")

        for cron_job in cron_jobs:
            # Check the trigger first so an unusable expression costs no query.
            trigger = get_cron_job_trigger(cron_job.cron_expression)  # type: ignore
            if not trigger:
                continue

            result = await session.execute(
                select(Job).where(Job.id == cron_job.job_id)
            )
            queried_job = result.scalars().first()

            if queried_job is None:
                LOG.warning(
                    f"Skipping cron job {cron_job.id}: "
                    f"related job {cron_job.job_id} not found"
                )
                continue

            LOG.info(f"Adding job: {queried_job}")

            # Plain dict of column values: the scheduled callable unpacks its
            # argument with ``**``, so it must be a mapping, not an ORM object.
            job_dict = {
                c.key: getattr(queried_job, c.key)
                for c in queried_job.__table__.columns
            }

            scheduler.add_job(
                insert_job_from_cron_job,
                trigger,
                id=cron_job.id,
                args=[job_dict],
            )
+69 -44
View File
@@ -3,18 +3,22 @@ import logging
import datetime
from typing import Any
# PDM
from sqlalchemy import delete as sql_delete
from sqlalchemy import select
from sqlalchemy import update as sql_update
from sqlalchemy.ext.asyncio import AsyncSession
# LOCAL
from api.backend.database.utils import format_list_for_query
from api.backend.database.common import query as common_query
from api.backend.database.common import insert as common_insert
from api.backend.database.common import update as common_update
from api.backend.database.queries.job.job_queries import JOB_INSERT_QUERY
from api.backend.database.base import AsyncSessionLocal
from api.backend.database.models import Job
LOG = logging.getLogger("Job")
async def insert(item: dict[str, Any]) -> None:
if check_for_job_completion(item["id"]):
async def insert(item: dict[str, Any], db: AsyncSession) -> None:
existing = await db.get(Job, item["id"])
if existing:
await multi_field_update_job(
item["id"],
{
@@ -24,57 +28,76 @@ async def insert(item: dict[str, Any]) -> None:
"elements": item["elements"],
"status": "Queued",
"result": [],
"time_created": datetime.datetime.now().isoformat(),
"time_created": datetime.datetime.now(datetime.timezone.utc),
"chat": None,
},
db,
)
return
common_insert(
JOB_INSERT_QUERY,
(
item["id"],
item["url"],
item["elements"],
item["user"],
item["time_created"],
item["result"],
item["status"],
item["chat"],
item["job_options"],
item["agent_mode"],
item["prompt"],
),
job = Job(
id=item["id"],
url=item["url"],
elements=item["elements"],
user=item["user"],
time_created=datetime.datetime.now(datetime.timezone.utc),
result=item["result"],
status=item["status"],
chat=item["chat"],
job_options=item["job_options"],
agent_mode=item["agent_mode"],
prompt=item["prompt"],
)
db.add(job)
await db.commit()
LOG.debug(f"Inserted item: {item}")
async def check_for_job_completion(id: str) -> dict[str, Any]:
    """Look up a job by primary key and return its column values.

    Args:
        id: Primary key of the job.

    Returns:
        A new dict mapping column name to value, or ``{}`` when no job has
        this id.
    """
    async with AsyncSessionLocal() as session:
        job = await session.get(Job, id)

        if job is None:
            return {}

        # Return a copy of the column values instead of ``job.__dict__``:
        # that dict is the instance's own attribute storage and also carries
        # the private ``_sa_instance_state`` entry.
        return {c.key: getattr(job, c.key) for c in job.__table__.columns}
async def get_queued_job():
    """Fetch one job whose status is ``"Queued"``.

    Jobs are ordered by ``time_created`` descending and limited to one, so
    the most recently created queued job is the one returned.

    Returns:
        A new dict mapping column name to value, or ``None`` when no job is
        queued.
    """
    async with AsyncSessionLocal() as session:
        stmt = (
            select(Job)
            .where(Job.status == "Queued")
            .order_by(Job.time_created.desc())
            .limit(1)
        )

        result = await session.execute(stmt)
        job = result.scalars().first()
        LOG.debug(f"Got queued job: {job}")

        if job is None:
            return None

        # Return a copy of the column values instead of ``job.__dict__``:
        # that dict is the instance's own attribute storage and also carries
        # the private ``_sa_instance_state`` entry.
        return {c.key: getattr(job, c.key) for c in job.__table__.columns}
async def update_job(ids: list[str], field: str, value: Any):
    """Set one column to ``value`` on every job whose id is in ``ids``.

    Args:
        ids: Ids of the jobs to update.
        field: Name of the column to assign.
        value: New value for that column.
    """
    async with AsyncSessionLocal() as db:
        changes = {field: value}
        outcome = await db.execute(
            sql_update(Job).where(Job.id.in_(ids)).values(changes)
        )
        await db.commit()
        LOG.debug(f"Updated job count: {outcome.rowcount}")
async def multi_field_update_job(
    id: str, fields: dict[str, Any], session: AsyncSession | None = None
):
    """Update several columns of one job in a single statement.

    Args:
        id: Id of the job to update.
        fields: Column name -> new value pairs.
        session: Session to run on. When omitted, a session is created for
            this call and closed again afterwards; a caller-supplied session
            is committed but left open.
    """
    owns_session = not session
    db = AsyncSessionLocal() if owns_session else session

    try:
        await db.execute(sql_update(Job).where(Job.id == id).values(**fields))
        await db.commit()
        LOG.debug(f"Updated job {id} fields: {fields}")
    finally:
        # Only close what this function opened.
        if owns_session:
            await db.close()
async def delete_jobs(jobs: list[str]):
@@ -82,7 +105,9 @@ async def delete_jobs(jobs: list[str]):
LOG.debug("No jobs to delete.")
return False
query = f"DELETE FROM jobs WHERE id IN {format_list_for_query(jobs)}"
res = common_update(query, tuple(jobs))
return res > 0
async with AsyncSessionLocal() as session:
stmt = sql_delete(Job).where(Job.id.in_(jobs))
res = await session.execute(stmt)
await session.commit()
LOG.debug(f"Deleted jobs: {res.rowcount}")
return res.rowcount > 0
+63 -33
View File
@@ -8,8 +8,10 @@ from io import StringIO
# PDM
from fastapi import Depends, APIRouter
from sqlalchemy import select
from fastapi.encoders import jsonable_encoder
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from apscheduler.triggers.cron import CronTrigger # type: ignore
# LOCAL
@@ -18,10 +20,12 @@ from api.backend.constants import MEDIA_DIR, MEDIA_TYPES, RECORDINGS_DIR
from api.backend.scheduler import scheduler
from api.backend.schemas.job import Job, UpdateJobs, DownloadJob, DeleteScrapeJobs
from api.backend.auth.schemas import User
from api.backend.schemas.cron import CronJob, DeleteCronJob
from api.backend.database.utils import format_list_for_query
from api.backend.schemas.cron import CronJob as PydanticCronJob
from api.backend.schemas.cron import DeleteCronJob
from api.backend.database.base import get_db
from api.backend.auth.auth_utils import get_current_user
from api.backend.database.common import query
from api.backend.database.models import Job as DatabaseJob
from api.backend.database.models import CronJob
from api.backend.job.utils.text_utils import clean_text
from api.backend.job.models.job_options import FetchOptions
from api.backend.routers.handle_exceptions import handle_exceptions
@@ -49,14 +53,14 @@ async def update(update_jobs: UpdateJobs, _: User = Depends(get_current_user)):
@job_router.post("/submit-scrape-job")
@handle_exceptions(logger=LOG)
async def submit_scrape_job(job: Job):
async def submit_scrape_job(job: Job, db: AsyncSession = Depends(get_db)):
LOG.info(f"Recieved job: {job}")
if not job.id:
job.id = uuid.uuid4().hex
job_dict = job.model_dump()
await insert(job_dict)
await insert(job_dict, db)
return JSONResponse(
content={"id": job.id, "message": "Job submitted successfully."}
@@ -66,34 +70,49 @@ async def submit_scrape_job(job: Job):
@job_router.post("/retrieve-scrape-jobs")
@handle_exceptions(logger=LOG)
async def retrieve_scrape_jobs(
    fetch_options: FetchOptions,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the current user's jobs, newest first.

    When ``fetch_options.chat`` is set only the ``chat`` column is selected;
    otherwise full job rows are returned.

    Args:
        fetch_options: Request options; ``chat`` switches the selection.
        user: Authenticated user whose jobs are listed.
        db: Request-scoped database session.

    Returns:
        A JSON response holding the encoded rows.
    """
    LOG.info(
        f"Retrieving jobs for account: {user.email if user.email else 'Guest User'}"
    )

    if fetch_options.chat:
        stmt = select(DatabaseJob.chat).filter(DatabaseJob.user == user.email)
    else:
        stmt = select(DatabaseJob).filter(DatabaseJob.user == user.email)

    # The reversal below only means "newest first" if the rows arrive oldest
    # first, so the ordering has to be explicit; without it row order is
    # whatever the database happens to return.
    stmt = stmt.order_by(DatabaseJob.time_created.asc())

    results = await db.execute(stmt)
    rows = results.all() if fetch_options.chat else results.scalars().all()

    return JSONResponse(content=jsonable_encoder(rows[::-1]))
@job_router.get("/job/{id}")
@handle_exceptions(logger=LOG)
async def job(
    id: str, user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)
):
    """Return the job with this ``id`` if it belongs to the current user.

    The response body is a JSON list of the matching rows, which is empty
    when nothing matches.
    """
    LOG.info(f"Retrieving jobs for account: {user.email}")

    stmt = select(DatabaseJob).filter(
        DatabaseJob.user == user.email, DatabaseJob.id == id
    )
    matches = (await db.execute(stmt)).scalars().all()
    payload = [match.__dict__ for match in matches]

    return JSONResponse(content=jsonable_encoder(payload))
@job_router.post("/download")
@handle_exceptions(logger=LOG)
async def download(download_job: DownloadJob):
async def download(download_job: DownloadJob, db: AsyncSession = Depends(get_db)):
LOG.info(f"Downloading job with ids: {download_job.ids}")
job_query = (
f"SELECT * FROM jobs WHERE id IN {format_list_for_query(download_job.ids)}"
)
results = query(job_query, tuple(download_job.ids))
stmt = select(DatabaseJob).where(DatabaseJob.id.in_(download_job.ids))
result = await db.execute(stmt)
results = [job.__dict__ for job in result.scalars().all()]
if download_job.job_format == "csv":
csv_buffer = StringIO()
@@ -151,10 +170,12 @@ async def download(download_job: DownloadJob):
@job_router.get("/job/{id}/convert-to-csv")
@handle_exceptions(logger=LOG)
async def convert_to_csv(id: str, db: AsyncSession = Depends(get_db)):
    """Return the job with this ``id`` after passing it through ``clean_job_format``."""
    result = await db.execute(select(DatabaseJob).filter(DatabaseJob.id == id))
    as_dicts = [row.__dict__ for row in result.scalars().all()]

    return JSONResponse(content=clean_job_format(as_dicts))
@job_router.post("/delete-scrape-jobs")
@@ -170,25 +191,34 @@ async def delete(delete_scrape_jobs: DeleteScrapeJobs):
@job_router.post("/schedule-cron-job")
@handle_exceptions(logger=LOG)
async def schedule_cron_job(
    cron_job: PydanticCronJob,
    db: AsyncSession = Depends(get_db),
):
    """Persist a cron job and register it with the scheduler.

    Validation happens before anything is written, so a rejected request
    leaves no cron row behind.

    Args:
        cron_job: Request payload; ``id`` and the two timestamps are filled
            in when absent.
        db: Request-scoped database session.

    Returns:
        * 400 when the cron expression does not yield a trigger,
        * 404 when the referenced job does not exist,
        * a success message once the cron job is stored and scheduled.
    """
    # An expression that yields no trigger must not reach the scheduler.
    trigger = get_cron_job_trigger(cron_job.cron_expression)
    if not trigger:
        return JSONResponse(
            status_code=400, content={"error": "Invalid cron expression"}
        )

    stmt = select(DatabaseJob).where(DatabaseJob.id == cron_job.job_id)
    result = await db.execute(stmt)
    queried_job = result.scalars().first()

    if not queried_job:
        return JSONResponse(status_code=404, content={"error": "Related job not found"})

    if not cron_job.id:
        cron_job.id = uuid.uuid4().hex

    now = datetime.datetime.now()
    if not cron_job.time_created:
        cron_job.time_created = now

    if not cron_job.time_updated:
        cron_job.time_updated = now

    await insert_cron_job(CronJob(**cron_job.model_dump()))

    # The scheduled callable unpacks its argument with ``**``, so hand it a
    # plain dict of column values rather than the ORM instance.
    job_dict = {
        c.key: getattr(queried_job, c.key) for c in queried_job.__table__.columns
    }

    scheduler.add_job(
        insert_job_from_cron_job,
        trigger,
        id=cron_job.id,
        args=[job_dict],
    )

    return JSONResponse(content={"message": "Cron job scheduled successfully."})
@@ -202,7 +232,7 @@ async def delete_cron_job_request(request: DeleteCronJob):
content={"error": "Cron job id is required."}, status_code=400
)
delete_cron_job(request.id, request.user_email)
await delete_cron_job(request.id, request.user_email)
scheduler.remove_job(request.id)
return JSONResponse(content={"message": "Cron job deleted successfully."})
@@ -211,7 +241,7 @@ async def delete_cron_job_request(request: DeleteCronJob):
@job_router.get("/cron-jobs")
@handle_exceptions(logger=LOG)
async def get_cron_jobs_request(user: User = Depends(get_current_user)):
    """Return the current user's cron jobs as a JSON response."""
    return JSONResponse(content=jsonable_encoder(await get_cron_jobs(user.email)))
+3 -1
View File
@@ -28,7 +28,9 @@ def clean_job_format(jobs: list[dict[str, Any]]) -> dict[str, Any]:
"xpath": value.get("xpath", ""),
"text": text,
"user": job.get("user", ""),
"time_created": job.get("time_created", ""),
"time_created": job.get(
"time_created", ""
).isoformat(),
}
)
+2 -2
View File
@@ -1,4 +1,4 @@
# PDM
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.schedulers.asyncio import AsyncIOScheduler
scheduler = BackgroundScheduler()
scheduler = AsyncIOScheduler()
+10 -4
View File
@@ -3,9 +3,11 @@ import logging
# PDM
from fastapi import Depends, APIRouter
from sqlalchemy.ext.asyncio import AsyncSession
# LOCAL
from api.backend.auth.schemas import User
from api.backend.database.base import get_db
from api.backend.auth.auth_utils import get_current_user
from api.backend.routers.handle_exceptions import handle_exceptions
from api.backend.database.queries.statistics.statistic_queries import (
@@ -20,12 +22,16 @@ stats_router = APIRouter()
@stats_router.get("/statistics/get-average-element-per-link")
@handle_exceptions(logger=LOG)
async def get_average_element_per_link(
    user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)
):
    """Return the result of ``average_elements_per_link`` for the current user."""
    averages = await average_elements_per_link(db, user.email)
    return averages
@stats_router.get("/statistics/get-average-jobs-per-day")
@handle_exceptions(logger=LOG)
async def average_jobs_per_day(
    user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)
):
    """Return the result of ``get_jobs_per_day`` for the current user."""
    return await get_jobs_per_day(db, user.email)
+69 -24
View File
@@ -1,15 +1,21 @@
# STL
import os
import sqlite3
from typing import Generator
from unittest.mock import patch
import asyncio
from typing import Any, Generator, AsyncGenerator
# PDM
import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from proxy import Proxy
from sqlalchemy import text
from sqlalchemy.pool import NullPool
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
# LOCAL
from api.backend.database.schema import INIT_QUERY
from api.backend.app import app
from api.backend.database.base import get_db
from api.backend.database.models import Base
from api.backend.tests.constants import TEST_DB_PATH
@@ -21,18 +27,6 @@ def running_proxy():
proxy.shutdown()
@pytest.fixture(scope="session", autouse=True)
def patch_database_path():
with patch("api.backend.database.common.DATABASE_PATH", TEST_DB_PATH):
yield
@pytest.fixture(scope="session", autouse=True)
def patch_recordings_enabled():
with patch("api.backend.job.scraping.scraping.RECORDINGS_ENABLED", False):
yield
@pytest.fixture(scope="session")
def test_db_path() -> str:
    """Expose ``TEST_DB_PATH`` as a session-scoped fixture."""
    return TEST_DB_PATH
@@ -46,18 +40,69 @@ def test_db(test_db_path: str) -> Generator[str, None, None]:
if os.path.exists(test_db_path):
os.remove(test_db_path)
conn = sqlite3.connect(test_db_path)
cursor = conn.cursor()
# Create async engine for test database
test_db_url = f"sqlite+aiosqlite:///{test_db_path}"
engine = create_async_engine(test_db_url, echo=False)
for query in INIT_QUERY.strip().split(";"):
query = query.strip()
if query:
cursor.execute(query)
async def setup_db():
async with engine.begin() as conn:
# Create tables
# LOCAL
from api.backend.database.models import Base
conn.commit()
conn.close()
await conn.run_sync(Base.metadata.create_all)
# Run setup
asyncio.run(setup_db())
yield test_db_path
if os.path.exists(test_db_path):
os.remove(test_db_path)
@pytest_asyncio.fixture(scope="session")
async def test_engine():
    """Session-scoped async engine on the test SQLite file.

    All tables known to ``Base.metadata`` are created before the engine is
    yielded; the engine is disposed after the session's tests have run.
    """
    engine = create_async_engine(
        f"sqlite+aiosqlite:///{TEST_DB_PATH}", poolclass=NullPool
    )

    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)

    yield engine

    await engine.dispose()
@pytest_asyncio.fixture(scope="function")
async def db_session(test_engine: Any) -> AsyncGenerator[AsyncSession, None]:
    """Yield a session bound to ``test_engine`` and empty all tables afterwards.

    The cleanup runs in a ``finally`` block, so every table is emptied once
    the test is done with the session, whether or not an exception was raised.
    """
    session_factory = async_sessionmaker(
        bind=test_engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )

    async with session_factory() as session:
        try:
            yield session
        finally:
            # Truncate all tables after each test, walking the metadata's
            # sorted table list in reverse.
            table_names = [t.name for t in reversed(Base.metadata.sorted_tables)]
            for table_name in table_names:
                await session.execute(text(f"DELETE FROM {table_name}"))
            await session.commit()
@pytest.fixture()
def override_get_db(db_session: AsyncSession):
    """Build an async generator function that yields the test's ``db_session``.

    The returned callable is what the ``client`` fixture installs as the
    override for the ``get_db`` dependency.
    """

    async def _override() -> AsyncGenerator[AsyncSession, None]:
        yield db_session

    return _override
@pytest_asyncio.fixture()
async def client(override_get_db: Any) -> AsyncGenerator[AsyncClient, None]:
    """Yield an HTTP client wired to the app with ``get_db`` overridden.

    Requests go through an in-process ASGI transport. All dependency
    overrides on the app are cleared once the client has been closed.
    """
    app.dependency_overrides[get_db] = override_get_db

    async with AsyncClient(
        transport=ASGITransport(app=app), base_url="http://test"
    ) as http:
        yield http

    app.dependency_overrides.clear()
+45 -25
View File
@@ -1,45 +1,65 @@
# STL
from unittest.mock import AsyncMock, patch
import random
from datetime import datetime, timezone
# PDM
import pytest
from fastapi.testclient import TestClient
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
# LOCAL
from api.backend.app import app
from api.backend.schemas.job import DownloadJob
from api.backend.tests.factories.job_factory import create_completed_job
from api.backend.database.models import Job
client = TestClient(app)
mocked_job = create_completed_job().model_dump()
mock_results = [mocked_job]
mocked_random_int = 123456
@pytest.mark.asyncio
async def test_download(client: AsyncClient, db_session: AsyncSession):
    """POST /download returns a stored job as a CSV attachment."""
    # Insert a test job into the DB
    job_id = "test-job-id"
    scraped_page = {
        "https://example.com": {
            "element_name": [{"xpath": "//div", "text": "example"}]
        }
    }
    db_session.add(
        Job(
            id=job_id,
            url="https://example.com",
            elements=[],
            user="test@example.com",
            time_created=datetime.now(timezone.utc),
            result=[scraped_page],
            status="Completed",
            chat=None,
            job_options={},
            agent_mode=False,
            prompt="",
            favorite=False,
        )
    )
    await db_session.commit()

    # Force predictable randint
    random.seed(0)

    # Build request
    request_body = DownloadJob(ids=[job_id], job_format="csv").model_dump()
    response = await client.post("/download", json=request_body)

    # Assertions
    assert response.status_code == 200
    assert response.headers["Content-Disposition"] == "attachment; filename=export.csv"

    # Validate CSV contents
    csv_content = response.content.decode("utf-8")
    lines = csv_content.strip().split("\n")

    expected_header = '"id","url","element_name","xpath","text","user","time_created"'
    assert lines[0].strip() == expected_header

    for expected_cell in (
        '"https://example.com"',
        '"element_name"',
        '"//div"',
        '"example"',
    ):
        assert expected_cell in lines[1]
+24 -17
View File
@@ -5,15 +5,17 @@ from datetime import datetime
# PDM
import pytest
from httpx import AsyncClient
from sqlalchemy import select
from fastapi.testclient import TestClient
from playwright.async_api import Route, Cookie, async_playwright
from sqlalchemy.ext.asyncio import AsyncSession
# LOCAL
from api.backend.app import app
from api.backend.job.models import Proxy, Element, JobOptions
from api.backend.schemas.job import Job
from api.backend.database.common import query
from api.backend.job.scraping.scraping import scrape
from api.backend.database.models import Job as JobModel
from api.backend.job.scraping.add_custom import add_custom_items
logging.basicConfig(level=logging.DEBUG)
@@ -68,7 +70,7 @@ async def test_add_custom_items():
@pytest.mark.asyncio
async def test_proxies():
async def test_proxies(client: AsyncClient, db_session: AsyncSession):
job = Job(
url="https://example.com",
elements=[Element(xpath="//div", name="test")],
@@ -84,14 +86,22 @@ async def test_proxies():
time_created=datetime.now().isoformat(),
)
response = client.post("/submit-scrape-job", json=job.model_dump())
response = await client.post("/submit-scrape-job", json=job.model_dump())
assert response.status_code == 200
jobs = query("SELECT * FROM jobs")
job = jobs[0]
stmt = select(JobModel)
result = await db_session.execute(stmt)
jobs = result.scalars().all()
assert job is not None
assert job["job_options"]["proxies"] == [
assert len(jobs) > 0
job_from_db = jobs[0]
job_dict = job_from_db.__dict__
job_dict.pop("_sa_instance_state", None)
assert job_dict is not None
print(job_dict)
assert job_dict["job_options"]["proxies"] == [
{
"server": "127.0.0.1:8080",
"username": "user",
@@ -99,12 +109,9 @@ async def test_proxies():
}
]
response = await scrape(
id=job["id"],
url=job["url"],
xpaths=[Element(**e) for e in job["elements"]],
job_options=job["job_options"],
)
example_response = response[0]["https://example.com/"]
assert example_response is not {}
# Verify the job was stored correctly in the database
assert job_dict["url"] == "https://example.com"
assert job_dict["status"] == "Queued"
assert len(job_dict["elements"]) == 1
assert job_dict["elements"][0]["xpath"] == "//div"
assert job_dict["elements"][0]["name"] == "test"
-3
View File
@@ -12,7 +12,6 @@ from api.backend.job import update_job, get_queued_job
from api.backend.job.models import Element
from api.backend.worker.logger import LOG
from api.backend.ai.agent.agent import scrape_with_agent
from api.backend.database.startup import init_database
from api.backend.worker.constants import (
TO,
EMAIL,
@@ -124,8 +123,6 @@ async def process_job():
async def main():
LOG.info("Starting job worker...")
init_database()
RECORDINGS_DIR.mkdir(parents=True, exist_ok=True)
while True:
+4 -1
View File
@@ -3,7 +3,7 @@ FROM python:3.10.12-slim as pybuilder
RUN apt-get update && \
apt-get install -y curl && \
apt-get install -y x11vnc xvfb uvicorn wget gnupg supervisor libgl1 libglx-mesa0 libglx0 vainfo libva-dev libva-glx2 libva-drm2 ffmpeg && \
apt-get install -y x11vnc xvfb uvicorn wget gnupg supervisor libgl1 libglx-mesa0 libglx0 vainfo libva-dev libva-glx2 libva-drm2 ffmpeg pkg-config default-libmysqlclient-dev gcc && \
curl -LsSf https://astral.sh/uv/install.sh | sh && \
apt-get remove -y curl && \
apt-get autoremove -y && \
@@ -37,6 +37,9 @@ RUN touch /project/app/data/database.db
EXPOSE 5900
COPY alembic /project/app/alembic
COPY alembic.ini /project/app/alembic.ini
COPY start.sh /project/app/start.sh
CMD [ "supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf" ]
Generated
+621 -857
View File
File diff suppressed because it is too large Load Diff
+8 -2
View File
@@ -12,7 +12,7 @@ dependencies = [
"asyncio>=3.4.3",
"aiohttp>=3.9.5",
"bs4>=0.0.2",
"lxml[html_clean]>=5.2.2",
"lxml>=5.2.2",
"lxml-stubs>=0.5.1",
"fake-useragent>=1.5.1",
"requests-html>=0.10.0",
@@ -24,7 +24,6 @@ dependencies = [
"python-keycloak>=4.2.0",
"fastapi-keycloak>=1.0.11",
"pymongo>=4.8.0",
"motor[asyncio]>=3.5.0",
"python-jose[cryptography]>=3.3.0",
"passlib[bcrypt]>=1.7.4",
"selenium-wire>=5.1.0",
@@ -44,6 +43,13 @@ dependencies = [
"html2text>=2025.4.15",
"proxy-py>=2.4.10",
"browserforge==1.2.1",
"sqlalchemy>=2.0.41",
"aiosqlite>=0.21.0",
"alembic>=1.16.4",
"asyncpg>=0.30.0",
"aiomysql>=0.2.0",
"psycopg2-binary>=2.9.10",
"mysqlclient>=2.2.7",
]
requires-python = ">=3.10"
readme = "README.md"
@@ -1,6 +1,7 @@
export const download = async (ids: string[], jobFormat: string) => {
const response = await fetch("/api/download", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ data: { ids, job_format: jobFormat } }),
});
+2
View File
@@ -2,6 +2,8 @@
RECORDINGS_ENABLED=${RECORDINGS_ENABLED:-true}
pdm run alembic upgrade head
if [ "$RECORDINGS_ENABLED" == "false" ]; then
pdm run python -m api.backend.worker.job_worker
else