Mirror of https://github.com/jaypyles/Scraperr.git (synced 2025-11-24 18:16:41 +00:00)

Compare commits: master...chore/refa (15 commits)
| Author | SHA1 | Date |
|---|---|---|
| | a1e7a34ca0 | |
| | 73b1786516 | |
| | 881cf3cd99 | |
| | 6b487e5564 | |
| | f6a96c3cc1 | |
| | da81fb7b32 | |
| | ae4147d204 | |
| | e0bc0b4482 | |
| | f1fca2a0ba | |
| | e3359daa1e | |
| | 719d4c9f28 | |
| | 25b08d766e | |
| | 3c06a1ae14 | |
| | 1f426989af | |
| | c1b3c68c76 | |
.github/workflows/docker-image.yml (vendored, 15 lines changed)

@@ -26,36 +26,31 @@ jobs:
run: |
echo "Version is ${{ inputs.version }}"

- name: Set up QEMU
uses: docker/setup-qemu-action@v3

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}

- name: Build and push frontend (multi-arch)
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

- name: Build and push frontend
uses: docker/build-push-action@v5
with:
context: .
file: ./docker/frontend/Dockerfile
push: true
platforms: linux/amd64,linux/arm64
tags: |
${{ secrets.DOCKERHUB_USERNAME }}/scraperr:latest
${{ secrets.DOCKERHUB_USERNAME }}/scraperr:${{ inputs.version }}

- name: Build and push api (multi-arch)
- name: Build and push api
uses: docker/build-push-action@v5
with:
context: .
file: ./docker/api/Dockerfile
push: true
platforms: linux/amd64,linux/arm64
tags: |
${{ secrets.DOCKERHUB_USERNAME }}/scraperr_api:latest
${{ secrets.DOCKERHUB_USERNAME }}/scraperr_api:${{ inputs.version }}
.github/workflows/merge.yml (vendored, 12 lines changed)

@@ -10,14 +10,14 @@ on:
- master

jobs:
# TODO: Re-enable once browser forge is fixed for camoufox, or else tests will never pass
# tests:
# uses: ./.github/workflows/tests.yml
# secrets:
# openai_key: ${{ secrets.OPENAI_KEY }}
# discord_webhook_url: ${{ secrets.DISCORD_WEBHOOK_URL }}
tests:
uses: ./.github/workflows/tests.yml
secrets:
openai_key: ${{ secrets.OPENAI_KEY }}
discord_webhook_url: ${{ secrets.DISCORD_WEBHOOK_URL }}

version:
needs: tests
uses: ./.github/workflows/version.yml
secrets:
git_token: ${{ secrets.GPAT_TOKEN }}
.github/workflows/pr.yml (vendored, 5 lines changed)

@@ -8,6 +8,11 @@ on:
workflow_dispatch:

jobs:
checkout:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

tests:
uses: ./.github/workflows/tests.yml
secrets:
.github/workflows/pytest.yml (vendored, 4 lines changed)

@@ -10,8 +10,6 @@ jobs:
- name: Checkout
uses: actions/checkout@v4

- uses: actions/setup-node@v3

- name: Set env
run: echo "ENV=test" >> $GITHUB_ENV

@@ -22,7 +20,7 @@ jobs:
run: pdm install

- name: Install playwright
run: pdm run playwright install --with-deps
run: pdm run playwright install

- name: Run tests
run: PYTHONPATH=. pdm run pytest -v -ra api/backend/tests
.github/workflows/version.yml (vendored, 8 lines changed)

@@ -19,7 +19,6 @@ jobs:
outputs:
version: ${{ steps.set_version.outputs.version }}
version_bump: ${{ steps.check_version_bump.outputs.version_bump }}

steps:
- name: Checkout

@@ -48,11 +47,10 @@ jobs:
id: check_version_bump
run: |
COMMIT_MSG=$(git log -1 --pretty=%B)

if [[ $COMMIT_MSG =~ .*\[no\ bump\].* ]]; then
echo "version_bump=false" >> $GITHUB_OUTPUT
else
if [[ $COMMIT_MSG =~ ^feat\(breaking\) ]]; then
echo "version_bump=true" >> $GITHUB_OUTPUT
elif [[ $COMMIT_MSG =~ .*\[no\ bump\].* ]]; then
echo "version_bump=false" >> $GITHUB_OUTPUT
fi

- name: Skip version bump
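For reference, a minimal Python sketch of the commit-message check that the step above performs in bash; the function name and example messages are illustrative only and not part of the repository:

import re

def should_bump_version(commit_msg: str) -> bool | None:
    # "[no bump]" anywhere in the message suppresses the bump
    if "[no bump]" in commit_msg:
        return False
    # a message starting with "feat(breaking)" forces a bump
    if re.match(r"^feat\(breaking\)", commit_msg):
        return True
    # any other message leaves the output unset, as in the workflow
    return None

print(should_bump_version("feat(breaking): new API"))   # True
print(should_bump_version("chore: tidy up [no bump]"))  # False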
alembic.ini (147 lines removed)

@@ -1,147 +0,0 @@
# A generic, single database configuration.

[alembic]
# path to migration scripts.
# this is typically a path given in POSIX (e.g. forward slashes)
# format, relative to the token %(here)s which refers to the location of this
# ini file
script_location = %(here)s/alembic

# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s

# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory. for multiple paths, the path separator
# is defined by "path_separator" below.
prepend_sys_path = .

# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
# Any required deps can be installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =

# max length of characters to apply to the "slug" field
# truncate_slug_length = 40

# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false

# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false

# version location specification; This defaults
# to <script_location>/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "path_separator"
# below.
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions

# path_separator; This indicates what character is used to split lists of file
# paths, including version_locations and prepend_sys_path within configparser
# files such as alembic.ini.
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
# to provide os-dependent path splitting.
#
# Note that in order to support legacy alembic.ini files, this default does NOT
# take place if path_separator is not present in alembic.ini. If this
# option is omitted entirely, fallback logic is as follows:
#
# 1. Parsing of the version_locations option falls back to using the legacy
#    "version_path_separator" key, which if absent then falls back to the legacy
#    behavior of splitting on spaces and/or commas.
# 2. Parsing of the prepend_sys_path option falls back to the legacy
#    behavior of splitting on spaces, commas, or colons.
#
# Valid values for path_separator are:
#
# path_separator = :
# path_separator = ;
# path_separator = space
# path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
path_separator = os

# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false

# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8

# database URL. This is consumed by the user-maintained env.py script only.
# other means of configuring database URLs may be customized within the env.py
# file.
sqlalchemy.url = driver://user:pass@localhost/dbname


[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples

# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME

# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
# hooks = ruff
# ruff.type = module
# ruff.module = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME

# Alternatively, use the exec runner to execute a binary found on your PATH
# hooks = ruff
# ruff.type = exec
# ruff.executable = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME

# Logging configuration. This is also consumed by the user-maintained
# env.py script only.
[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARNING
handlers = console
qualname =

[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

@@ -1 +0,0 @@
Generic single-database configuration.
alembic/env.py (103 lines removed)

@@ -1,103 +0,0 @@
# STL
import os
import sys
from logging.config import fileConfig

# PDM
from dotenv import load_dotenv
from sqlalchemy import pool, engine_from_config

# LOCAL
from alembic import context
from api.backend.database.base import Base
from api.backend.database.models import Job, User, CronJob  # type: ignore

load_dotenv()

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "api")))

# Load the raw async database URL
raw_database_url = os.getenv("DATABASE_URL", "sqlite+aiosqlite:///data/database.db")

# Map async dialects to sync ones
driver_downgrade_map = {
    "sqlite+aiosqlite": "sqlite",
    "postgresql+asyncpg": "postgresql",
    "mysql+aiomysql": "mysql",
}

# Extract scheme and convert if async
for async_driver, sync_driver in driver_downgrade_map.items():
    if raw_database_url.startswith(async_driver + "://"):
        sync_database_url = raw_database_url.replace(async_driver, sync_driver, 1)
        break
else:
    # No async driver detected — assume it's already sync
    sync_database_url = raw_database_url


# Apply it to Alembic config
config = context.config
config.set_main_option("sqlalchemy.url", sync_database_url)

# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = Base.metadata


def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    This configures the context with just a URL
    and not an Engine, though an Engine is acceptable
    here as well. By skipping the Engine creation
    we don't even need a DBAPI to be available.

    Calls to context.execute() here emit the given string to the
    script output.

    """
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()


def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    In this scenario we need to create an Engine
    and associate a connection with the context.

    """
    connectable = engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with connectable.connect() as connection:
        context.configure(connection=connection, target_metadata=target_metadata)

        with context.begin_transaction():
            context.run_migrations()


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
@@ -1,28 +0,0 @@
"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}


def upgrade() -> None:
    """Upgrade schema."""
    ${upgrades if upgrades else "pass"}


def downgrade() -> None:
    """Downgrade schema."""
    ${downgrades if downgrades else "pass"}
@@ -1,67 +0,0 @@
"""initial revision

Revision ID: 6aa921d2e637
Revises:
Create Date: 2025-07-12 20:17:44.448034

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '6aa921d2e637'
down_revision: Union[str, Sequence[str], None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('users',
        sa.Column('email', sa.String(length=255), nullable=False),
        sa.Column('hashed_password', sa.String(length=255), nullable=False),
        sa.Column('full_name', sa.String(length=255), nullable=True),
        sa.Column('disabled', sa.Boolean(), nullable=True),
        sa.PrimaryKeyConstraint('email')
    )
    op.create_table('jobs',
        sa.Column('id', sa.String(length=64), nullable=False),
        sa.Column('url', sa.String(length=2048), nullable=False),
        sa.Column('elements', sa.JSON(), nullable=False),
        sa.Column('user', sa.String(length=255), nullable=True),
        sa.Column('time_created', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
        sa.Column('result', sa.JSON(), nullable=False),
        sa.Column('status', sa.String(length=50), nullable=False),
        sa.Column('chat', sa.JSON(), nullable=True),
        sa.Column('job_options', sa.JSON(), nullable=True),
        sa.Column('agent_mode', sa.Boolean(), nullable=False),
        sa.Column('prompt', sa.String(length=1024), nullable=True),
        sa.Column('favorite', sa.Boolean(), nullable=False),
        sa.ForeignKeyConstraint(['user'], ['users.email'], ),
        sa.PrimaryKeyConstraint('id')
    )
    op.create_table('cron_jobs',
        sa.Column('id', sa.String(length=64), nullable=False),
        sa.Column('user_email', sa.String(length=255), nullable=False),
        sa.Column('job_id', sa.String(length=64), nullable=False),
        sa.Column('cron_expression', sa.String(length=255), nullable=False),
        sa.Column('time_created', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
        sa.Column('time_updated', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
        sa.ForeignKeyConstraint(['job_id'], ['jobs.id'], ),
        sa.ForeignKeyConstraint(['user_email'], ['users.email'], ),
        sa.PrimaryKeyConstraint('id')
    )
    # ### end Alembic commands ###


def downgrade() -> None:
    """Downgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table('cron_jobs')
    op.drop_table('jobs')
    op.drop_table('users')
    # ### end Alembic commands ###
@@ -7,7 +7,6 @@ from camoufox import AsyncCamoufox
from playwright.async_api import Page

# LOCAL
from api.backend.constants import RECORDINGS_ENABLED
from api.backend.ai.clients import ask_ollama, ask_open_ai, open_ai_key
from api.backend.job.models import CapturedElement
from api.backend.worker.logger import LOG
@@ -30,13 +29,11 @@ async def scrape_with_agent(agent_job: dict[str, Any]):
LOG.info(f"Starting work for agent job: {agent_job}")
pages = set()

proxy = None

if agent_job["job_options"]["proxies"]:
proxy = random.choice(agent_job["job_options"]["proxies"])
LOG.info(f"Using proxy: {proxy}")

async with AsyncCamoufox(headless=not RECORDINGS_ENABLED, proxy=proxy) as browser:
async with AsyncCamoufox(headless=True) as browser:
page: Page = await browser.new_page()

await add_custom_items(
@@ -67,7 +64,7 @@ async def scrape_with_agent(agent_job: dict[str, Any]):
xpaths = parse_response(response)

captured_elements = await capture_elements(
page, xpaths, agent_job["job_options"].get("return_html", False)
page, xpaths, agent_job["job_options"]["return_html"]
)

final_url = page.url
@@ -15,6 +15,7 @@ from api.backend.scheduler import scheduler
from api.backend.ai.ai_router import ai_router
from api.backend.job.job_router import job_router
from api.backend.auth.auth_router import auth_router
from api.backend.database.startup import init_database
from api.backend.stats.stats_router import stats_router
from api.backend.job.cron_scheduling.cron_scheduling import start_cron_scheduler

@@ -35,8 +36,10 @@ async def lifespan(_: FastAPI):
# Startup
LOG.info("Starting application...")

init_database()

LOG.info("Starting cron scheduler...")
await start_cron_scheduler(scheduler)
start_cron_scheduler(scheduler)
scheduler.start()

LOG.info("Cron scheduler started successfully")
@@ -6,11 +6,9 @@ from datetime import timedelta
# PDM
from fastapi import Depends, APIRouter, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession

# LOCAL
from api.backend.auth.schemas import User, Token, UserCreate
from api.backend.database.base import AsyncSessionLocal, get_db
from api.backend.auth.auth_utils import (
ACCESS_TOKEN_EXPIRE_MINUTES,
get_current_user,
@@ -18,7 +16,7 @@ from api.backend.auth.auth_utils import (
get_password_hash,
create_access_token,
)
from api.backend.database.models import User as DatabaseUser
from api.backend.database.common import update
from api.backend.routers.handle_exceptions import handle_exceptions

auth_router = APIRouter()
@@ -28,8 +26,8 @@ LOG = logging.getLogger("Auth")

@auth_router.post("/auth/token", response_model=Token)
@handle_exceptions(logger=LOG)
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: AsyncSession = Depends(get_db)):
user = await authenticate_user(form_data.username, form_data.password, db)
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
user = await authenticate_user(form_data.username, form_data.password)

if not user:
raise HTTPException(
@@ -58,15 +56,8 @@ async def create_user(user: UserCreate):
user_dict["hashed_password"] = hashed_password
del user_dict["password"]

async with AsyncSessionLocal() as session:
new_user = DatabaseUser(
email=user.email,
hashed_password=user_dict["hashed_password"],
full_name=user.full_name,
)

session.add(new_user)
await session.commit()
query = "INSERT INTO users (email, hashed_password, full_name) VALUES (?, ?, ?)"
_ = update(query, (user_dict["email"], hashed_password, user_dict["full_name"]))

return user_dict
@@ -8,15 +8,12 @@ from datetime import datetime, timedelta
from jose import JWTError, jwt
from dotenv import load_dotenv
from fastapi import Depends, HTTPException, status
from sqlalchemy import select
from passlib.context import CryptContext
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession

# LOCAL
from api.backend.auth.schemas import User, UserInDB, TokenData
from api.backend.database.base import get_db
from api.backend.database.models import User as UserModel
from api.backend.database.common import query

LOG = logging.getLogger("Auth")

@@ -40,24 +37,18 @@ def get_password_hash(password: str):
return pwd_context.hash(password)


async def get_user(session: AsyncSession, email: str) -> UserInDB | None:
stmt = select(UserModel).where(UserModel.email == email)
result = await session.execute(stmt)
user = result.scalars().first()
async def get_user(email: str):
user_query = "SELECT * FROM users WHERE email = ?"
user = query(user_query, (email,))[0]

if not user:
return None
return

return UserInDB(
email=str(user.email),
hashed_password=str(user.hashed_password),
full_name=str(user.full_name),
disabled=bool(user.disabled),
)
return UserInDB(**user)


async def authenticate_user(email: str, password: str, db: AsyncSession):
user = await get_user(db, email)
async def authenticate_user(email: str, password: str):
user = await get_user(email)

if not user:
return False
@@ -83,9 +74,7 @@ def create_access_token(
return encoded_jwt


async def get_current_user(
db: AsyncSession = Depends(get_db), token: str = Depends(oauth2_scheme)
):
async def get_current_user(token: str = Depends(oauth2_scheme)):
LOG.debug(f"Getting current user with token: {token}")

if not token:
@@ -93,7 +82,7 @@ async def get_current_user(
return EMPTY_USER

if len(token.split(".")) != 3:
LOG.debug(f"Malformed token: {token}")
LOG.error(f"Malformed token: {token}")
return EMPTY_USER

try:
@@ -128,7 +117,7 @@ async def get_current_user(
LOG.error(f"Exception occurred: {e}")
return EMPTY_USER

user = await get_user(db, email=token_data.email or "")
user = await get_user(email=token_data.email or "")

if user is None:
return EMPTY_USER
@@ -136,7 +125,7 @@ async def get_current_user(
return user


async def require_user(db: AsyncSession, token: str = Depends(oauth2_scheme)):
async def require_user(token: str = Depends(oauth2_scheme)):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
@@ -161,7 +150,7 @@ async def require_user(db: AsyncSession, token: str = Depends(oauth2_scheme)):
except JWTError:
raise credentials_exception

user = await get_user(db, email=token_data.email or "")
user = await get_user(email=token_data.email or "")

if user is None:
raise credentials_exception
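A small illustration of the malformed-token guard used in get_current_user above: a JWT is expected to have exactly three dot-separated segments, so anything else is rejected before decoding. The sample strings are made up for illustration:

def looks_like_jwt(token: str) -> bool:
    # header.payload.signature -> three segments when split on "."
    return len(token.split(".")) == 3

print(looks_like_jwt("aaa.bbb.ccc"))   # True: shaped like a JWT
print(looks_like_jwt("not-a-token"))   # False: treated as malformed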
@@ -2,7 +2,7 @@
import os
from pathlib import Path

DATABASE_URL = os.getenv("DATABASE_URL", "sqlite+aiosqlite:///data/database.db")
DATABASE_PATH = "data/database.db"
RECORDINGS_DIR = Path("media/recordings")
RECORDINGS_ENABLED = os.getenv("RECORDINGS_ENABLED", "true").lower() == "true"
MEDIA_DIR = Path("media")
@@ -0,0 +1,5 @@
# LOCAL
from .common import insert, update, connect
from .schema import INIT_QUERY

__all__ = ["insert", "update", "INIT_QUERY", "connect"]
@@ -1,26 +0,0 @@
# STL
from typing import AsyncGenerator

# PDM
from sqlalchemy.orm import declarative_base
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

# LOCAL
from api.backend.constants import DATABASE_URL

engine = create_async_engine(DATABASE_URL, echo=False, future=True)

AsyncSessionLocal = async_sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
    expire_on_commit=False,
    class_=AsyncSession,
)

Base = declarative_base()


async def get_db() -> AsyncGenerator[AsyncSession, None]:
    async with AsyncSessionLocal() as session:
        yield session
api/backend/database/common.py (new file, 89 lines)

@@ -0,0 +1,89 @@
# STL
import logging
import sqlite3
from typing import Any, Optional

# LOCAL
from api.backend.constants import DATABASE_PATH
from api.backend.database.utils import format_json, format_sql_row_to_python

LOG = logging.getLogger("Database")


def connect():
    connection = sqlite3.connect(DATABASE_PATH)
    connection.set_trace_callback(print)
    cursor = connection.cursor()
    return cursor


def insert(query: str, values: tuple[Any, ...]):
    connection = sqlite3.connect(DATABASE_PATH)
    cursor = connection.cursor()
    copy = list(values)
    format_json(copy)

    try:
        _ = cursor.execute(query, copy)
        connection.commit()

    except sqlite3.Error as e:
        LOG.error(f"An error occurred: {e}")

    finally:
        cursor.close()
        connection.close()


def query(query: str, values: Optional[tuple[Any, ...]] = None):
    connection = sqlite3.connect(DATABASE_PATH)
    connection.row_factory = sqlite3.Row
    cursor = connection.cursor()
    rows = []
    try:
        if values:
            _ = cursor.execute(query, values)
        else:
            _ = cursor.execute(query)

        rows = cursor.fetchall()

    finally:
        cursor.close()
        connection.close()

    formatted_rows: list[dict[str, Any]] = []

    for row in rows:
        row = dict(row)
        formatted_row = format_sql_row_to_python(row)
        formatted_rows.append(formatted_row)

    return formatted_rows


def update(query: str, values: Optional[tuple[Any, ...]] = None):
    connection = sqlite3.connect(DATABASE_PATH)
    cursor = connection.cursor()

    copy = None

    if values:
        copy = list(values)
        format_json(copy)

    try:
        if copy:
            res = cursor.execute(query, copy)
        else:
            res = cursor.execute(query)
        connection.commit()
        return res.rowcount
    except sqlite3.Error as e:
        LOG.error(f"An error occurred: {e}")

    finally:
        cursor.close()
        connection.close()

    return 0
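A minimal usage sketch of the raw-sqlite helpers added in api/backend/database/common.py; the email and job id below are illustrative only and assume the jobs table defined elsewhere in this change set:

from api.backend.database.common import query, update

# parameterized SELECT; rows come back as plain dicts
rows = query("SELECT id, status FROM jobs WHERE user = ?", ("user@example.com",))
print(rows)  # e.g. [{"id": "...", "status": "Queued"}]

# parameterized UPDATE; the helper returns the affected row count
changed = update("UPDATE jobs SET status = ? WHERE id = ?", ("Completed", "some-job-id"))
print(changed)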
@@ -1,65 +0,0 @@
# PDM
from sqlalchemy import JSON, Column, String, Boolean, DateTime, ForeignKey, func
from sqlalchemy.orm import relationship

# LOCAL
from api.backend.database.base import Base


class User(Base):
    __tablename__ = "users"

    email = Column(String(255), primary_key=True, nullable=False)
    hashed_password = Column(String(255), nullable=False)
    full_name = Column(String(255), nullable=True)
    disabled = Column(Boolean, default=False)

    jobs = relationship("Job", back_populates="user_obj", cascade="all, delete-orphan")
    cron_jobs = relationship(
        "CronJob", back_populates="user_obj", cascade="all, delete-orphan"
    )


class Job(Base):
    __tablename__ = "jobs"

    id = Column(String(64), primary_key=True, nullable=False)
    url = Column(String(2048), nullable=False)
    elements = Column(JSON, nullable=False)
    user = Column(String(255), ForeignKey("users.email"), nullable=True)
    time_created = Column(
        DateTime(timezone=True), server_default=func.now(), nullable=False
    )
    result = Column(JSON, nullable=False)
    status = Column(String(50), nullable=False)
    chat = Column(JSON, nullable=True)
    job_options = Column(JSON, nullable=True)
    agent_mode = Column(Boolean, default=False, nullable=False)
    prompt = Column(String(1024), nullable=True)
    favorite = Column(Boolean, default=False, nullable=False)

    user_obj = relationship("User", back_populates="jobs")
    cron_jobs = relationship(
        "CronJob", back_populates="job_obj", cascade="all, delete-orphan"
    )


class CronJob(Base):
    __tablename__ = "cron_jobs"

    id = Column(String(64), primary_key=True, nullable=False)
    user_email = Column(String(255), ForeignKey("users.email"), nullable=False)
    job_id = Column(String(64), ForeignKey("jobs.id"), nullable=False)
    cron_expression = Column(String(255), nullable=False)
    time_created = Column(
        DateTime(timezone=True), server_default=func.now(), nullable=False
    )
    time_updated = Column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
    )

    user_obj = relationship("User", back_populates="cron_jobs")
    job_obj = relationship("Job", back_populates="cron_jobs")
@@ -0,0 +1,4 @@
# LOCAL
from .job.job_queries import DELETE_JOB_QUERY, JOB_INSERT_QUERY

__all__ = ["JOB_INSERT_QUERY", "DELETE_JOB_QUERY"]
@@ -2,64 +2,57 @@
import logging
from typing import Any

# PDM
from sqlalchemy import delete as sql_delete
from sqlalchemy import select
from sqlalchemy import update as sql_update

# LOCAL
from api.backend.database.base import AsyncSessionLocal
from api.backend.database.models import Job
from api.backend.database.utils import format_list_for_query
from api.backend.database.common import query, insert, update

JOB_INSERT_QUERY = """
INSERT INTO jobs
(id, url, elements, user, time_created, result, status, chat, job_options, agent_mode, prompt)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"""

DELETE_JOB_QUERY = """
DELETE FROM jobs WHERE id IN ()
"""

LOG = logging.getLogger("Database")


async def insert_job(item: dict[str, Any]) -> None:
async with AsyncSessionLocal() as session:
job = Job(
id=item["id"],
url=item["url"],
elements=item["elements"],
user=item["user"],
time_created=item["time_created"],
result=item["result"],
status=item["status"],
chat=item["chat"],
job_options=item["job_options"],
agent_mode=item["agent_mode"],
prompt=item["prompt"],
)
session.add(job)
await session.commit()
LOG.info(f"Inserted item: {item}")
def insert_job(item: dict[str, Any]) -> None:
insert(
JOB_INSERT_QUERY,
(
item["id"],
item["url"],
item["elements"],
item["user"],
item["time_created"],
item["result"],
item["status"],
item["chat"],
item["job_options"],
item["agent_mode"],
item["prompt"],
),
)
LOG.info(f"Inserted item: {item}")


async def get_queued_job():
async with AsyncSessionLocal() as session:
stmt = (
select(Job)
.where(Job.status == "Queued")
.order_by(Job.time_created.desc())
.limit(1)
)
result = await session.execute(stmt)
job = result.scalars().first()
queued_job_query = (
"SELECT * FROM jobs WHERE status = 'Queued' ORDER BY time_created DESC LIMIT 1"
)

if job:
LOG.info(f"Got queued job: {job}")

return job
res = query(queued_job_query)
LOG.info(f"Got queued job: {res}")
return res[0] if res else None


async def update_job(ids: list[str], updates: dict[str, Any]):
if not updates:
return

async with AsyncSessionLocal() as session:
stmt = sql_update(Job).where(Job.id.in_(ids)).values(**updates)
result = await session.execute(stmt)
await session.commit()
LOG.debug(f"Updated job count: {result.rowcount}")
async def update_job(ids: list[str], field: str, value: Any):
query = f"UPDATE jobs SET {field} = ? WHERE id IN {format_list_for_query(ids)}"
res = update(query, tuple([value] + ids))
LOG.info(f"Updated job: {res}")


async def delete_jobs(jobs: list[str]):
@@ -67,9 +60,9 @@ async def delete_jobs(jobs: list[str]):
LOG.info("No jobs to delete.")
return False

async with AsyncSessionLocal() as session:
stmt = sql_delete(Job).where(Job.id.in_(jobs))
result = await session.execute(stmt)
await session.commit()
LOG.info(f"Deleted jobs count: {result.rowcount}")
return result.rowcount
query = f"DELETE FROM jobs WHERE id IN {format_list_for_query(jobs)}"
res = update(query, tuple(jobs))

LOG.info(f"Deleted jobs: {res}")

return res
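A rough sketch of how the rewritten update_job composes its SQL, under the assumption (not shown in this diff) that format_list_for_query(ids) renders a placeholder group such as "(?, ?)":

ids = ["job-1", "job-2"]
field, value = "status", "Completed"
placeholders = "(" + ", ".join("?" for _ in ids) + ")"  # stand-in for format_list_for_query(ids)
sql = f"UPDATE jobs SET {field} = ? WHERE id IN {placeholders}"
params = tuple([value] + ids)
print(sql)     # UPDATE jobs SET status = ? WHERE id IN (?, ?)
print(params)  # ('Completed', 'job-1', 'job-2')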
@@ -1,43 +1,41 @@
# PDM
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

# LOCAL
from api.backend.database.models import Job
from api.backend.database.common import query


async def average_elements_per_link(session: AsyncSession, user_email: str):
date_func = func.date(Job.time_created)
async def average_elements_per_link(user: str):
job_query = """
SELECT
DATE(time_created) AS date,
AVG(json_array_length(elements)) AS average_elements,
COUNT(*) AS count
FROM
jobs
WHERE
status = 'Completed' AND user = ?
GROUP BY
DATE(time_created)
ORDER BY
date ASC;
"""
results = query(job_query, (user,))

stmt = (
select(
date_func.label("date"),
func.avg(func.json_array_length(Job.elements)).label("average_elements"),
func.count().label("count"),
)
.where(Job.status == "Completed", Job.user == user_email)
.group_by(date_func)
.order_by("date")
)

result = await session.execute(stmt)
rows = result.all()
return [dict(row._mapping) for row in rows]
return results


async def get_jobs_per_day(session: AsyncSession, user_email: str):
date_func = func.date(Job.time_created)
async def get_jobs_per_day(user: str):
job_query = """
SELECT
DATE(time_created) AS date,
COUNT(*) AS job_count
FROM
jobs
WHERE
status = 'Completed' AND user = ?
GROUP BY
DATE(time_created)
ORDER BY
date ASC;
"""
results = query(job_query, (user,))

stmt = (
select(
date_func.label("date"),
func.count().label("job_count"),
)
.where(Job.status == "Completed", Job.user == user_email)
.group_by(date_func)
.order_by("date")
)

result = await session.execute(stmt)
rows = result.all()
return [dict(row._mapping) for row in rows]
return results
api/backend/database/schema/__init__.py (new file, 3 lines)

@@ -0,0 +1,3 @@
from .schema import INIT_QUERY

__all__ = ["INIT_QUERY"]
api/backend/database/schema/schema.py (new file, 34 lines)

@@ -0,0 +1,34 @@
INIT_QUERY = """
CREATE TABLE IF NOT EXISTS jobs (
    id STRING PRIMARY KEY NOT NULL,
    url STRING NOT NULL,
    elements JSON NOT NULL,
    user STRING,
    time_created DATETIME NOT NULL,
    result JSON NOT NULL,
    status STRING NOT NULL,
    chat JSON,
    job_options JSON
);

CREATE TABLE IF NOT EXISTS users (
    email STRING PRIMARY KEY NOT NULL,
    hashed_password STRING NOT NULL,
    full_name STRING,
    disabled BOOLEAN
);

CREATE TABLE IF NOT EXISTS cron_jobs (
    id STRING PRIMARY KEY NOT NULL,
    user_email STRING NOT NULL,
    job_id STRING NOT NULL,
    cron_expression STRING NOT NULL,
    time_created DATETIME NOT NULL,
    time_updated DATETIME NOT NULL,
    FOREIGN KEY (job_id) REFERENCES jobs(id)
);

ALTER TABLE jobs ADD COLUMN agent_mode BOOLEAN NOT NULL DEFAULT FALSE;
ALTER TABLE jobs ADD COLUMN prompt STRING;
ALTER TABLE jobs ADD COLUMN favorite BOOLEAN NOT NULL DEFAULT FALSE;
"""
@@ -1,8 +1,6 @@
# STL
import logging

# PDM
from sqlalchemy.exc import IntegrityError
import sqlite3

# LOCAL
from api.backend.constants import (
@@ -11,46 +9,61 @@ from api.backend.constants import (
DEFAULT_USER_PASSWORD,
DEFAULT_USER_FULL_NAME,
)
from api.backend.database.base import Base, AsyncSessionLocal, engine
from api.backend.auth.auth_utils import get_password_hash
from api.backend.database.models import User
from api.backend.database.common import insert, connect
from api.backend.database.schema import INIT_QUERY

LOG = logging.getLogger("Database")

async def init_database():
LOG.info("Creating database schema...")

async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
def execute_startup_query():
cursor = connect()

for query in INIT_QUERY.strip().split(";"):
query = query.strip()

if not query:
continue

try:
LOG.info(f"Executing query: {query}")
_ = cursor.execute(query)

except sqlite3.OperationalError as e:
if "duplicate column name" in str(e).lower():
LOG.warning(f"Skipping duplicate column error: {e}")
continue
else:
LOG.error(f"Error executing query: {query}")
raise

cursor.close()


def init_database():
execute_startup_query()

if not REGISTRATION_ENABLED:
default_user_email = DEFAULT_USER_EMAIL
default_user_password = DEFAULT_USER_PASSWORD
default_user_full_name = DEFAULT_USER_FULL_NAME

if not (default_user_email and default_user_password and default_user_full_name):
LOG.error("DEFAULT_USER_* env vars are not set!")
if (
not default_user_email
or not default_user_password
or not default_user_full_name
):
LOG.error(
"DEFAULT_USER_EMAIL, DEFAULT_USER_PASSWORD, or DEFAULT_USER_FULL_NAME is not set!"
)
exit(1)

async with AsyncSessionLocal() as session:
user = await session.get(User, default_user_email)
if user:
LOG.info("Default user already exists. Skipping creation.")
return

LOG.info("Creating default user...")
new_user = User(
email=default_user_email,
hashed_password=get_password_hash(default_user_password),
full_name=default_user_full_name,
disabled=False,
)

try:
session.add(new_user)
await session.commit()
LOG.info(f"Created default user: {default_user_email}")
except IntegrityError as e:
await session.rollback()
LOG.warning(f"Could not create default user (already exists?): {e}")

query = "INSERT INTO users (email, hashed_password, full_name) VALUES (?, ?, ?)"
_ = insert(
query,
(
default_user_email,
get_password_hash(default_user_password),
default_user_full_name,
),
)
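For context, a compact Python sketch of the startup pattern above: each statement in INIT_QUERY is executed separately, and "duplicate column name" errors from the ALTER TABLE statements are treated as "column already exists" so startup stays idempotent. The in-memory database and statements here are illustrative only:

import sqlite3

conn = sqlite3.connect(":memory:")
statements = [
    "CREATE TABLE IF NOT EXISTS jobs (id STRING PRIMARY KEY NOT NULL)",
    "ALTER TABLE jobs ADD COLUMN agent_mode BOOLEAN NOT NULL DEFAULT FALSE",
    "ALTER TABLE jobs ADD COLUMN agent_mode BOOLEAN NOT NULL DEFAULT FALSE",  # second run is a no-op
]
for stmt in statements:
    try:
        conn.execute(stmt)
    except sqlite3.OperationalError as e:
        if "duplicate column name" in str(e).lower():
            continue  # column was added on a previous startup
        raise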
@@ -1,7 +1,6 @@
# STL
import json
from typing import Any
from datetime import datetime


def format_list_for_query(ids: list[str]):
@@ -29,9 +28,3 @@ def format_json(items: list[Any]):
if isinstance(item, (dict, list)):
formatted_item = json.dumps(item)
items[idx] = formatted_item


def parse_datetime(dt_str: str) -> datetime:
if dt_str.endswith("Z"):
dt_str = dt_str.replace("Z", "+00:00")  # valid ISO format for UTC
return datetime.fromisoformat(dt_str)
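A short usage sketch of the parse_datetime helper shown above; the timestamps are made up for illustration:

from datetime import datetime

def parse_datetime(dt_str: str) -> datetime:
    # a trailing "Z" is rewritten so datetime.fromisoformat accepts it on older Pythons
    if dt_str.endswith("Z"):
        dt_str = dt_str.replace("Z", "+00:00")
    return datetime.fromisoformat(dt_str)

print(parse_datetime("2025-07-12T20:17:44Z"))       # timezone-aware UTC datetime
print(parse_datetime("2025-07-12T20:17:44+00:00"))  # same instant, already ISO-formatted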
@@ -2,74 +2,83 @@
import uuid
import logging
import datetime
from typing import Any, List
from typing import Any

# PDM
from sqlalchemy import select
from apscheduler.triggers.cron import CronTrigger
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.schedulers.background import BackgroundScheduler

# LOCAL
from api.backend.job import insert as insert_job
from api.backend.database.base import AsyncSessionLocal
from api.backend.database.models import Job, CronJob
from api.backend.schemas.cron import CronJob
from api.backend.database.common import query, insert

LOG = logging.getLogger("Cron")


async def insert_cron_job(cron_job: CronJob) -> bool:
async with AsyncSessionLocal() as session:
session.add(cron_job)
await session.commit()
def insert_cron_job(cron_job: CronJob):
query = """
INSERT INTO cron_jobs (id, user_email, job_id, cron_expression, time_created, time_updated)
VALUES (?, ?, ?, ?, ?, ?)
"""

values = (
cron_job.id,
cron_job.user_email,
cron_job.job_id,
cron_job.cron_expression,
cron_job.time_created,
cron_job.time_updated,
)

insert(query, values)

return True


async def delete_cron_job(id: str, user_email: str) -> bool:
async with AsyncSessionLocal() as session:
stmt = select(CronJob).where(CronJob.id == id, CronJob.user_email == user_email)
result = await session.execute(stmt)
cron_job = result.scalars().first()
if cron_job:
await session.delete(cron_job)
await session.commit()
def delete_cron_job(id: str, user_email: str):
query = """
DELETE FROM cron_jobs
WHERE id = ? AND user_email = ?
"""

values = (id, user_email)
insert(query, values)

return True


async def get_cron_jobs(user_email: str) -> List[CronJob]:
async with AsyncSessionLocal() as session:
stmt = select(CronJob).where(CronJob.user_email == user_email)
result = await session.execute(stmt)
return list(result.scalars().all())
def get_cron_jobs(user_email: str):
cron_jobs = query("SELECT * FROM cron_jobs WHERE user_email = ?", (user_email,))

return cron_jobs


async def get_all_cron_jobs() -> List[CronJob]:
async with AsyncSessionLocal() as session:
stmt = select(CronJob)
result = await session.execute(stmt)
return list(result.scalars().all())
def get_all_cron_jobs():
cron_jobs = query("SELECT * FROM cron_jobs")

return cron_jobs


async def insert_job_from_cron_job(job: dict[str, Any]):
async with AsyncSessionLocal() as session:
await insert_job(
{
**job,
"id": uuid.uuid4().hex,
"status": "Queued",
"result": "",
"chat": None,
"time_created": datetime.datetime.now(datetime.timezone.utc),
"time_updated": datetime.datetime.now(datetime.timezone.utc),
},
session,
)
def insert_job_from_cron_job(job: dict[str, Any]):
insert_job(
{
**job,
"id": uuid.uuid4().hex,
"status": "Queued",
"result": "",
"chat": None,
"time_created": datetime.datetime.now(),
"time_updated": datetime.datetime.now(),
}
)


def get_cron_job_trigger(cron_expression: str):
expression_parts = cron_expression.split()

if len(expression_parts) != 5:
LOG.warning(f"Invalid cron expression: {cron_expression}")
print(f"Invalid cron expression: {cron_expression}")
return None

minute, hour, day, month, day_of_week = expression_parts
@@ -79,37 +88,19 @@ def get_cron_job_trigger(cron_expression: str):
)


async def start_cron_scheduler(scheduler: AsyncIOScheduler):
async with AsyncSessionLocal() as session:
stmt = select(CronJob)
result = await session.execute(stmt)
cron_jobs = result.scalars().all()
def start_cron_scheduler(scheduler: BackgroundScheduler):
cron_jobs = get_all_cron_jobs()

LOG.info(f"Cron jobs: {cron_jobs}")
LOG.info(f"Cron jobs: {cron_jobs}")

for cron_job in cron_jobs:
stmt = select(Job).where(Job.id == cron_job.job_id)
result = await session.execute(stmt)
queried_job = result.scalars().first()
for job in cron_jobs:
queried_job = query("SELECT * FROM jobs WHERE id = ?", (job["job_id"],))

LOG.info(f"Adding job: {queried_job}")
LOG.info(f"Adding job: {queried_job}")

trigger = get_cron_job_trigger(cron_job.cron_expression)  # type: ignore
if not trigger:
continue

job_dict = (
{
c.key: getattr(queried_job, c.key)
for c in queried_job.__table__.columns
}
if queried_job
else {}
)

scheduler.add_job(
insert_job_from_cron_job,
trigger,
id=cron_job.id,
args=[job_dict],
)
scheduler.add_job(
insert_job_from_cron_job,
get_cron_job_trigger(job["cron_expression"]),
id=job["id"],
args=[queried_job[0]],
)
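A minimal sketch of turning a 5-field cron expression into an APScheduler CronTrigger, as get_cron_job_trigger above does; the exact keyword mapping is assumed since the body of that call sits outside this hunk:

from apscheduler.triggers.cron import CronTrigger

def cron_trigger_from_expression(cron_expression: str) -> CronTrigger | None:
    parts = cron_expression.split()
    if len(parts) != 5:
        return None  # not a standard minute/hour/day/month/day-of-week expression
    minute, hour, day, month, day_of_week = parts
    return CronTrigger(
        minute=minute, hour=hour, day=day, month=month, day_of_week=day_of_week
    )

print(cron_trigger_from_expression("0 * * * *"))  # fires at the top of every hour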
@@ -1,103 +1,51 @@
# STL
import logging
import datetime
from typing import Any

# PDM
from sqlalchemy import delete as sql_delete
from sqlalchemy import select
from sqlalchemy import update as sql_update
from sqlalchemy.ext.asyncio import AsyncSession

# LOCAL
from api.backend.database.base import AsyncSessionLocal
from api.backend.database.models import Job
from api.backend.database.utils import format_list_for_query
from api.backend.database.common import query as common_query
from api.backend.database.common import insert as common_insert
from api.backend.database.common import update as common_update
from api.backend.database.queries.job.job_queries import JOB_INSERT_QUERY

LOG = logging.getLogger("Job")


async def insert(item: dict[str, Any], db: AsyncSession) -> None:
existing = await db.get(Job, item["id"])
if existing:
await multi_field_update_job(
def insert(item: dict[str, Any]) -> None:
common_insert(
JOB_INSERT_QUERY,
(
item["id"],
{
"agent_mode": item["agent_mode"],
"prompt": item["prompt"],
"job_options": item["job_options"],
"elements": item["elements"],
"status": "Queued",
"result": [],
"time_created": datetime.datetime.now(datetime.timezone.utc),
"chat": None,
},
db,
)
return

job = Job(
id=item["id"],
url=item["url"],
elements=item["elements"],
user=item["user"],
time_created=datetime.datetime.now(datetime.timezone.utc),
result=item["result"],
status=item["status"],
chat=item["chat"],
job_options=item["job_options"],
agent_mode=item["agent_mode"],
prompt=item["prompt"],
item["url"],
item["elements"],
item["user"],
item["time_created"],
item["result"],
item["status"],
item["chat"],
item["job_options"],
item["agent_mode"],
item["prompt"],
),
)

db.add(job)
await db.commit()
LOG.debug(f"Inserted item: {item}")


async def check_for_job_completion(id: str) -> dict[str, Any]:
async with AsyncSessionLocal() as session:
job = await session.get(Job, id)
return job.__dict__ if job else {}


async def get_queued_job():
async with AsyncSessionLocal() as session:
stmt = (
select(Job)
.where(Job.status == "Queued")
.order_by(Job.time_created.desc())
.limit(1)
)
result = await session.execute(stmt)
job = result.scalars().first()
LOG.debug(f"Got queued job: {job}")
return job.__dict__ if job else None
query = (
"SELECT * FROM jobs WHERE status = 'Queued' ORDER BY time_created DESC LIMIT 1"
)
res = common_query(query)
LOG.debug(f"Got queued job: {res}")
return res[0] if res else None


async def update_job(ids: list[str], field: str, value: Any):
async with AsyncSessionLocal() as session:
stmt = sql_update(Job).where(Job.id.in_(ids)).values({field: value})
res = await session.execute(stmt)
await session.commit()
LOG.debug(f"Updated job count: {res.rowcount}")


async def multi_field_update_job(
id: str, fields: dict[str, Any], session: AsyncSession | None = None
):
close_session = False
if not session:
session = AsyncSessionLocal()
close_session = True

try:
stmt = sql_update(Job).where(Job.id == id).values(**fields)
await session.execute(stmt)
await session.commit()
LOG.debug(f"Updated job {id} fields: {fields}")
finally:
if close_session:
await session.close()
query = f"UPDATE jobs SET {field} = ? WHERE id IN {format_list_for_query(ids)}"
res = common_update(query, tuple([value] + ids))
LOG.debug(f"Updated job: {res}")


async def delete_jobs(jobs: list[str]):
@@ -105,9 +53,7 @@ async def delete_jobs(jobs: list[str]):
LOG.debug("No jobs to delete.")
return False

async with AsyncSessionLocal() as session:
stmt = sql_delete(Job).where(Job.id.in_(jobs))
res = await session.execute(stmt)
await session.commit()
LOG.debug(f"Deleted jobs: {res.rowcount}")
return res.rowcount > 0
query = f"DELETE FROM jobs WHERE id IN {format_list_for_query(jobs)}"
res = common_update(query, tuple(jobs))

return res > 0
@@ -8,10 +8,8 @@ from io import StringIO
|
||||
|
||||
# PDM
|
||||
from fastapi import Depends, APIRouter
|
||||
from sqlalchemy import select
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from apscheduler.triggers.cron import CronTrigger # type: ignore
|
||||
|
||||
# LOCAL
|
||||
@@ -20,12 +18,10 @@ from api.backend.constants import MEDIA_DIR, MEDIA_TYPES, RECORDINGS_DIR
|
||||
from api.backend.scheduler import scheduler
|
||||
from api.backend.schemas.job import Job, UpdateJobs, DownloadJob, DeleteScrapeJobs
|
||||
from api.backend.auth.schemas import User
|
||||
from api.backend.schemas.cron import CronJob as PydanticCronJob
|
||||
from api.backend.schemas.cron import DeleteCronJob
|
||||
from api.backend.database.base import get_db
|
||||
from api.backend.schemas.cron import CronJob, DeleteCronJob
|
||||
from api.backend.database.utils import format_list_for_query
|
||||
from api.backend.auth.auth_utils import get_current_user
|
||||
from api.backend.database.models import Job as DatabaseJob
|
||||
from api.backend.database.models import CronJob
|
||||
from api.backend.database.common import query
|
||||
from api.backend.job.utils.text_utils import clean_text
|
||||
from api.backend.job.models.job_options import FetchOptions
|
||||
from api.backend.routers.handle_exceptions import handle_exceptions
|
||||
@@ -47,20 +43,20 @@ job_router = APIRouter()
@job_router.post("/update")
@handle_exceptions(logger=LOG)
async def update(update_jobs: UpdateJobs, _: User = Depends(get_current_user)):
"""Used to update jobs"""
await update_job(update_jobs.ids, update_jobs.field, update_jobs.value)
return {"message": "Jobs updated successfully"}

return JSONResponse(content={"message": "Jobs updated successfully."})


@job_router.post("/submit-scrape-job")
@handle_exceptions(logger=LOG)
async def submit_scrape_job(job: Job, db: AsyncSession = Depends(get_db)):
async def submit_scrape_job(job: Job):
LOG.info(f"Recieved job: {job}")

if not job.id:
job.id = uuid.uuid4().hex

job.id = uuid.uuid4().hex
job_dict = job.model_dump()
await insert(job_dict, db)
insert(job_dict)

return JSONResponse(
content={"id": job.id, "message": "Job submitted successfully."}
@@ -70,49 +66,32 @@ async def submit_scrape_job(job: Job, db: AsyncSession = Depends(get_db)):
@job_router.post("/retrieve-scrape-jobs")
@handle_exceptions(logger=LOG)
async def retrieve_scrape_jobs(
fetch_options: FetchOptions,
user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
fetch_options: FetchOptions, user: User = Depends(get_current_user)
):
LOG.info(
f"Retrieving jobs for account: {user.email if user.email else 'Guest User'}"
)
if fetch_options.chat:
stmt = select(DatabaseJob.chat).filter(DatabaseJob.user == user.email)
else:
stmt = select(DatabaseJob).filter(DatabaseJob.user == user.email)

results = await db.execute(stmt)
rows = results.all() if fetch_options.chat else results.scalars().all()

return JSONResponse(content=jsonable_encoder(rows[::-1]))
LOG.info(f"Retrieving jobs for account: {user.email}")
ATTRIBUTES = "chat" if fetch_options.chat else "*"
job_query = f"SELECT {ATTRIBUTES} FROM jobs WHERE user = ?"
results = query(job_query, (user.email,))
return JSONResponse(content=jsonable_encoder(results[::-1]))

@job_router.get("/job/{id}")
|
||||
@handle_exceptions(logger=LOG)
|
||||
async def job(
|
||||
id: str, user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
async def job(id: str, user: User = Depends(get_current_user)):
|
||||
LOG.info(f"Retrieving jobs for account: {user.email}")
|
||||
|
||||
stmt = select(DatabaseJob).filter(
|
||||
DatabaseJob.user == user.email, DatabaseJob.id == id
|
||||
)
|
||||
|
||||
results = await db.execute(stmt)
|
||||
|
||||
return JSONResponse(
|
||||
content=jsonable_encoder([job.__dict__ for job in results.scalars().all()])
|
||||
)
|
||||
job_query = "SELECT * FROM jobs WHERE user = ? AND id = ?"
|
||||
results = query(job_query, (user.email, id))
|
||||
return JSONResponse(content=jsonable_encoder(results))
|
||||
|
||||
|
||||
@job_router.post("/download")
|
||||
@handle_exceptions(logger=LOG)
|
||||
async def download(download_job: DownloadJob, db: AsyncSession = Depends(get_db)):
|
||||
async def download(download_job: DownloadJob):
|
||||
LOG.info(f"Downloading job with ids: {download_job.ids}")
|
||||
stmt = select(DatabaseJob).where(DatabaseJob.id.in_(download_job.ids))
|
||||
result = await db.execute(stmt)
|
||||
results = [job.__dict__ for job in result.scalars().all()]
|
||||
job_query = (
|
||||
f"SELECT * FROM jobs WHERE id IN {format_list_for_query(download_job.ids)}"
|
||||
)
|
||||
results = query(job_query, tuple(download_job.ids))
|
||||
|
||||
if download_job.job_format == "csv":
|
||||
csv_buffer = StringIO()
|
||||
@@ -170,12 +149,10 @@ async def download(download_job: DownloadJob, db: AsyncSession = Depends(get_db)
|
||||
|
||||
@job_router.get("/job/{id}/convert-to-csv")
|
||||
@handle_exceptions(logger=LOG)
|
||||
async def convert_to_csv(id: str, db: AsyncSession = Depends(get_db)):
|
||||
stmt = select(DatabaseJob).filter(DatabaseJob.id == id)
|
||||
results = await db.execute(stmt)
|
||||
jobs = results.scalars().all()
|
||||
|
||||
return JSONResponse(content=clean_job_format([job.__dict__ for job in jobs]))
|
||||
async def convert_to_csv(id: str):
|
||||
job_query = f"SELECT * FROM jobs WHERE id = ?"
|
||||
results = query(job_query, (id,))
|
||||
return JSONResponse(content=clean_job_format(results))
|
||||
|
||||
|
||||
@job_router.post("/delete-scrape-jobs")
|
||||
@@ -191,34 +168,25 @@ async def delete(delete_scrape_jobs: DeleteScrapeJobs):
|
||||
|
||||
@job_router.post("/schedule-cron-job")
|
||||
@handle_exceptions(logger=LOG)
|
||||
async def schedule_cron_job(
|
||||
cron_job: PydanticCronJob,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
async def schedule_cron_job(cron_job: CronJob):
|
||||
if not cron_job.id:
|
||||
cron_job.id = uuid.uuid4().hex
|
||||
|
||||
now = datetime.datetime.now()
|
||||
if not cron_job.time_created:
|
||||
cron_job.time_created = now
|
||||
cron_job.time_created = datetime.datetime.now()
|
||||
|
||||
if not cron_job.time_updated:
|
||||
cron_job.time_updated = now
|
||||
cron_job.time_updated = datetime.datetime.now()
|
||||
|
||||
await insert_cron_job(CronJob(**cron_job.model_dump()))
|
||||
insert_cron_job(cron_job)
|
||||
|
||||
stmt = select(DatabaseJob).where(DatabaseJob.id == cron_job.job_id)
|
||||
result = await db.execute(stmt)
|
||||
queried_job = result.scalars().first()
|
||||
|
||||
if not queried_job:
|
||||
return JSONResponse(status_code=404, content={"error": "Related job not found"})
|
||||
queried_job = query("SELECT * FROM jobs WHERE id = ?", (cron_job.job_id,))
|
||||
|
||||
scheduler.add_job(
|
||||
insert_job_from_cron_job,
|
||||
get_cron_job_trigger(cron_job.cron_expression),
|
||||
id=cron_job.id,
|
||||
args=[queried_job],
|
||||
args=[queried_job[0]],
|
||||
)
|
||||
|
||||
return JSONResponse(content={"message": "Cron job scheduled successfully."})
|
||||
@@ -232,7 +200,7 @@ async def delete_cron_job_request(request: DeleteCronJob):
|
||||
content={"error": "Cron job id is required."}, status_code=400
|
||||
)
|
||||
|
||||
await delete_cron_job(request.id, request.user_email)
|
||||
delete_cron_job(request.id, request.user_email)
|
||||
scheduler.remove_job(request.id)
|
||||
|
||||
return JSONResponse(content={"message": "Cron job deleted successfully."})
|
||||
@@ -241,7 +209,7 @@ async def delete_cron_job_request(request: DeleteCronJob):
|
||||
@job_router.get("/cron-jobs")
|
||||
@handle_exceptions(logger=LOG)
|
||||
async def get_cron_jobs_request(user: User = Depends(get_current_user)):
|
||||
cron_jobs = await get_cron_jobs(user.email)
|
||||
cron_jobs = get_cron_jobs(user.email)
|
||||
return JSONResponse(content=jsonable_encoder(cron_jobs))
|
||||
|
||||
|
||||
|
||||
@@ -174,9 +174,7 @@ async def scrape(

for page in pages:
elements.append(
await collect_scraped_elements(
page, xpaths, job_options.get("return_html", False)
)
await collect_scraped_elements(page, xpaths, job_options["return_html"])
)

return elements

@@ -28,9 +28,7 @@ def clean_job_format(jobs: list[dict[str, Any]]) -> dict[str, Any]:
"xpath": value.get("xpath", ""),
"text": text,
"user": job.get("user", ""),
"time_created": job.get(
"time_created", ""
).isoformat(),
"time_created": job.get("time_created", ""),
}
)


@@ -1,4 +1,4 @@
# PDM
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.schedulers.background import BackgroundScheduler

scheduler = AsyncIOScheduler()
scheduler = BackgroundScheduler()

@@ -3,11 +3,9 @@ import logging

# PDM
from fastapi import Depends, APIRouter
from sqlalchemy.ext.asyncio import AsyncSession

# LOCAL
from api.backend.auth.schemas import User
from api.backend.database.base import get_db
from api.backend.auth.auth_utils import get_current_user
from api.backend.routers.handle_exceptions import handle_exceptions
from api.backend.database.queries.statistics.statistic_queries import (
@@ -22,16 +20,12 @@ stats_router = APIRouter()

@stats_router.get("/statistics/get-average-element-per-link")
@handle_exceptions(logger=LOG)
async def get_average_element_per_link(
user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)
):
return await average_elements_per_link(db, user.email)
async def get_average_element_per_link(user: User = Depends(get_current_user)):
return await average_elements_per_link(user.email)


@stats_router.get("/statistics/get-average-jobs-per-day")
@handle_exceptions(logger=LOG)
async def average_jobs_per_day(
user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)
):
data = await get_jobs_per_day(db, user.email)
async def average_jobs_per_day(user: User = Depends(get_current_user)):
data = await get_jobs_per_day(user.email)
return data

@@ -1,21 +1,15 @@
# STL
import os
import asyncio
from typing import Any, Generator, AsyncGenerator
import sqlite3
from typing import Generator
from unittest.mock import patch

# PDM
import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from proxy import Proxy
from sqlalchemy import text
from sqlalchemy.pool import NullPool
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

# LOCAL
from api.backend.app import app
from api.backend.database.base import get_db
from api.backend.database.models import Base
from api.backend.database.schema import INIT_QUERY
from api.backend.tests.constants import TEST_DB_PATH


@@ -27,6 +21,18 @@ def running_proxy():
proxy.shutdown()


@pytest.fixture(scope="session", autouse=True)
def patch_database_path():
with patch("api.backend.database.common.DATABASE_PATH", TEST_DB_PATH):
yield


@pytest.fixture(scope="session", autouse=True)
def patch_recordings_enabled():
with patch("api.backend.job.scraping.scraping.RECORDINGS_ENABLED", False):
yield


@pytest.fixture(scope="session")
def test_db_path() -> str:
return TEST_DB_PATH
@@ -40,69 +46,18 @@ def test_db(test_db_path: str) -> Generator[str, None, None]:
if os.path.exists(test_db_path):
os.remove(test_db_path)

# Create async engine for test database
test_db_url = f"sqlite+aiosqlite:///{test_db_path}"
engine = create_async_engine(test_db_url, echo=False)
conn = sqlite3.connect(test_db_path)
cursor = conn.cursor()

async def setup_db():
async with engine.begin() as conn:
# Create tables
# LOCAL
from api.backend.database.models import Base
for query in INIT_QUERY.strip().split(";"):
query = query.strip()
if query:
cursor.execute(query)

await conn.run_sync(Base.metadata.create_all)

# Run setup
asyncio.run(setup_db())
conn.commit()
conn.close()

yield test_db_path

if os.path.exists(test_db_path):
os.remove(test_db_path)

@pytest_asyncio.fixture(scope="session")
|
||||
async def test_engine():
|
||||
test_db_url = f"sqlite+aiosqlite:///{TEST_DB_PATH}"
|
||||
engine = create_async_engine(test_db_url, poolclass=NullPool)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def db_session(test_engine: Any) -> AsyncGenerator[AsyncSession, None]:
|
||||
async_session = async_sessionmaker(
|
||||
bind=test_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
async with async_session() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
# Truncate all tables after each test
|
||||
for table in reversed(Base.metadata.sorted_tables):
|
||||
await session.execute(text(f"DELETE FROM {table.name}"))
|
||||
await session.commit()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def override_get_db(db_session: AsyncSession):
|
||||
async def _override() -> AsyncGenerator[AsyncSession, None]:
|
||||
yield db_session
|
||||
|
||||
return _override
|
||||
|
||||
|
||||
@pytest_asyncio.fixture()
|
||||
async def client(override_get_db: Any) -> AsyncGenerator[AsyncClient, None]:
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as c:
|
||||
yield c
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
@@ -1,65 +1,45 @@
# STL
import random
from datetime import datetime, timezone
from unittest.mock import AsyncMock, patch

# PDM
import pytest
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from fastapi.testclient import TestClient

# LOCAL
from api.backend.app import app
from api.backend.schemas.job import DownloadJob
from api.backend.database.models import Job
from api.backend.tests.factories.job_factory import create_completed_job

client = TestClient(app)

mocked_job = create_completed_job().model_dump()
mock_results = [mocked_job]
mocked_random_int = 123456


@pytest.mark.asyncio
async def test_download(client: AsyncClient, db_session: AsyncSession):
# Insert a test job into the DB
job_id = "test-job-id"
test_job = Job(
id=job_id,
url="https://example.com",
elements=[],
user="test@example.com",
time_created=datetime.now(timezone.utc),
result=[
{
"https://example.com": {
"element_name": [{"xpath": "//div", "text": "example"}]
}
}
],
status="Completed",
chat=None,
job_options={},
agent_mode=False,
prompt="",
favorite=False,
)
db_session.add(test_job)
await db_session.commit()
@patch("api.backend.job.job_router.query")
@patch("api.backend.job.job_router.random.randint")
async def test_download(mock_randint: AsyncMock, mock_query: AsyncMock):
# Ensure the mock returns immediately
mock_query.return_value = mock_results
mock_randint.return_value = mocked_random_int

# Force predictable randint
random.seed(0)
# Create a DownloadJob instance
download_job = DownloadJob(ids=[mocked_job["id"]], job_format="csv")

# Build request
download_job = DownloadJob(ids=[job_id], job_format="csv")
response = await client.post("/download", json=download_job.model_dump())
# Make a POST request to the /download endpoint
response = client.post("/download", json=download_job.model_dump())

# Assertions
assert response.status_code == 200
assert response.headers["Content-Disposition"] == "attachment; filename=export.csv"

# Validate CSV contents
# Check the content of the CSV
csv_content = response.content.decode("utf-8")
lines = csv_content.strip().split("\n")

assert (
lines[0].strip()
== '"id","url","element_name","xpath","text","user","time_created"'
expected_csv = (
f'"id","url","element_name","xpath","text","user","time_created"\r\n'
f'"{mocked_job["id"]}-{mocked_random_int}","https://example.com","element_name","//div","example",'
f'"{mocked_job["user"]}","{mocked_job["time_created"]}"\r\n'
)
assert '"https://example.com"' in lines[1]
assert '"element_name"' in lines[1]
assert '"//div"' in lines[1]
assert '"example"' in lines[1]
assert csv_content == expected_csv

@@ -5,17 +5,15 @@ from datetime import datetime

# PDM
import pytest
from httpx import AsyncClient
from sqlalchemy import select
from fastapi.testclient import TestClient
from playwright.async_api import Route, Cookie, async_playwright
from sqlalchemy.ext.asyncio import AsyncSession

# LOCAL
from api.backend.app import app
from api.backend.job.models import Proxy, Element, JobOptions
from api.backend.schemas.job import Job
from api.backend.database.models import Job as JobModel
from api.backend.database.common import query
from api.backend.job.scraping.scraping import scrape
from api.backend.job.scraping.add_custom import add_custom_items

logging.basicConfig(level=logging.DEBUG)
@@ -70,7 +68,7 @@ async def test_add_custom_items():


@pytest.mark.asyncio
async def test_proxies(client: AsyncClient, db_session: AsyncSession):
async def test_proxies():
job = Job(
url="https://example.com",
elements=[Element(xpath="//div", name="test")],
@@ -86,22 +84,14 @@ async def test_proxies(client: AsyncClient, db_session: AsyncSession):
time_created=datetime.now().isoformat(),
)

response = await client.post("/submit-scrape-job", json=job.model_dump())
response = client.post("/submit-scrape-job", json=job.model_dump())
assert response.status_code == 200

stmt = select(JobModel)
result = await db_session.execute(stmt)
jobs = result.scalars().all()
jobs = query("SELECT * FROM jobs")
job = jobs[0]

assert len(jobs) > 0
job_from_db = jobs[0]

job_dict = job_from_db.__dict__
job_dict.pop("_sa_instance_state", None)

assert job_dict is not None
print(job_dict)
assert job_dict["job_options"]["proxies"] == [
assert job is not None
assert job["job_options"]["proxies"] == [
{
"server": "127.0.0.1:8080",
"username": "user",
@@ -109,9 +99,12 @@ async def test_proxies(client: AsyncClient, db_session: AsyncSession):
}
]

# Verify the job was stored correctly in the database
assert job_dict["url"] == "https://example.com"
assert job_dict["status"] == "Queued"
assert len(job_dict["elements"]) == 1
assert job_dict["elements"][0]["xpath"] == "//div"
assert job_dict["elements"][0]["name"] == "test"
response = await scrape(
id=job["id"],
url=job["url"],
xpaths=[Element(**e) for e in job["elements"]],
job_options=job["job_options"],
)

example_response = response[0]["https://example.com/"]
assert example_response is not {}

@@ -12,6 +12,7 @@ from api.backend.job import update_job, get_queued_job
from api.backend.job.models import Element
from api.backend.worker.logger import LOG
from api.backend.ai.agent.agent import scrape_with_agent
from api.backend.database.startup import init_database
from api.backend.worker.constants import (
TO,
EMAIL,
@@ -123,6 +124,8 @@ async def process_job():
async def main():
LOG.info("Starting job worker...")

init_database()

RECORDINGS_DIR.mkdir(parents=True, exist_ok=True)

while True:

@@ -3,7 +3,7 @@ FROM python:3.10.12-slim as pybuilder

RUN apt-get update && \
apt-get install -y curl && \
apt-get install -y x11vnc xvfb uvicorn wget gnupg supervisor libgl1 libglx-mesa0 libglx0 vainfo libva-dev libva-glx2 libva-drm2 ffmpeg pkg-config default-libmysqlclient-dev gcc && \
apt-get install -y x11vnc xvfb uvicorn wget gnupg supervisor libgl1 libglx-mesa0 libglx0 vainfo libva-dev libva-glx2 libva-drm2 ffmpeg && \
curl -LsSf https://astral.sh/uv/install.sh | sh && \
apt-get remove -y curl && \
apt-get autoremove -y && \
@@ -37,9 +37,6 @@ RUN touch /project/app/data/database.db

EXPOSE 5900

COPY alembic /project/app/alembic
COPY alembic.ini /project/app/alembic.ini

COPY start.sh /project/app/start.sh

CMD [ "supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf" ]
@@ -6,7 +6,7 @@ WORKDIR /app
COPY package.json yarn.lock ./

# Install dependencies in a separate layer
RUN yarn install --frozen-lockfile --network-timeout 600000
RUN yarn install --frozen-lockfile

# Copy the rest of the application
COPY tsconfig.json /app/tsconfig.json

@@ -15,7 +15,7 @@ type: application
# This is the chart version. This version number should be incremented each time you make changes
# to the chart and its templates, including the app version.
# Versions are expected to follow Semantic Versioning (https://semver.org/)
version: 1.1.7
version: 1.1.1

# This is the version number of the application being deployed. This version number should be
# incremented each time you make changes to the application. Versions are not expected to

2
next-env.d.ts
vendored
2
next-env.d.ts
vendored
@@ -2,4 +2,4 @@
/// <reference types="next/image-types/global" />

// NOTE: This file should not be edited
// see https://nextjs.org/docs/pages/building-your-application/configuring/typescript for more information.
// see https://nextjs.org/docs/basic-features/typescript for more information.

@@ -12,7 +12,7 @@ dependencies = [
"asyncio>=3.4.3",
"aiohttp>=3.9.5",
"bs4>=0.0.2",
"lxml>=5.2.2",
"lxml[html_clean]>=5.2.2",
"lxml-stubs>=0.5.1",
"fake-useragent>=1.5.1",
"requests-html>=0.10.0",
@@ -24,6 +24,7 @@ dependencies = [
"python-keycloak>=4.2.0",
"fastapi-keycloak>=1.0.11",
"pymongo>=4.8.0",
"motor[asyncio]>=3.5.0",
"python-jose[cryptography]>=3.3.0",
"passlib[bcrypt]>=1.7.4",
"selenium-wire>=5.1.0",
@@ -42,14 +43,6 @@ dependencies = [
"camoufox>=0.4.11",
"html2text>=2025.4.15",
"proxy-py>=2.4.10",
"browserforge==1.2.1",
"sqlalchemy>=2.0.41",
"aiosqlite>=0.21.0",
"alembic>=1.16.4",
"asyncpg>=0.30.0",
"aiomysql>=0.2.0",
"psycopg2-binary>=2.9.10",
"mysqlclient>=2.2.7",
]
requires-python = ">=3.10"
readme = "README.md"

@@ -1,6 +1,4 @@
import { ExpandedTableInput } from "@/components/common/expanded-table-input";
import { UploadFile } from "@/components/common/upload-file";
import { useImportJobConfig } from "@/hooks/use-import-job-config";
import { RawJobOptions } from "@/types";
import {
Code as CodeIcon,
@@ -28,7 +26,6 @@ import {
useTheme,
} from "@mui/material";
import { Dispatch, SetStateAction, useEffect, useState } from "react";
import { toast } from "react-toastify";

export type AdvancedJobOptionsDialogProps = {
open: boolean;
@@ -46,7 +43,6 @@ export const AdvancedJobOptionsDialog = ({
multiPageScrapeEnabled = true,
}: AdvancedJobOptionsDialogProps) => {
const theme = useTheme();
const { handleUploadFile } = useImportJobConfig();
const [localJobOptions, setLocalJobOptions] =
useState<RawJobOptions>(jobOptions);

@@ -73,18 +69,6 @@ export const AdvancedJobOptionsDialog = ({
onClose();
};

const onUploadFile = async (file: File) => {
const errorOccured = await handleUploadFile(file);
if (errorOccured) {
handleClose();
toast.error("Failed to upload job config");
return;
} else {
handleClose();
toast.success("Job config uploaded successfully");
}
};

return (
<Dialog
open={open}
@@ -115,18 +99,11 @@ export const AdvancedJobOptionsDialog = ({
<Typography variant="h6" component="div">
Advanced Job Options
</Typography>
<Box sx={{ display: "flex", alignItems: "center", gap: 1 }}>
<UploadFile
message="Upload Job Config"
fileTypes={["application/json"]}
onUploadFile={onUploadFile}
/>
<Settings
sx={{
color: theme.palette.primary.contrastText,
}}
/>
</Box>
<Settings
sx={{
color: theme.palette.primary.contrastText,
}}
/>
</DialogTitle>

<DialogContent

@@ -1 +0,0 @@
export * from "./upload-file";
@@ -1,34 +0,0 @@
import { Box, Button, Typography } from "@mui/material";

export type UploadFileProps = {
message: string;
fileTypes?: string[];
onUploadFile: (file: File) => void;
};

export const UploadFile = ({
message,
fileTypes,
onUploadFile,
}: UploadFileProps) => {
const handleUploadFile = (event: React.ChangeEvent<HTMLInputElement>) => {
const file = event.target.files?.[0];
if (file) {
onUploadFile(file);
}
};

return (
<Box>
<Button variant="contained" component="label">
<Typography>{message}</Typography>
<input
type="file"
hidden
onChange={handleUploadFile}
accept={fileTypes?.join(",")}
/>
</Button>
</Box>
);
};
@@ -1,18 +1,18 @@
import StarIcon from "@mui/icons-material/Star";
import React from "react";
import {
Box,
Button,
Checkbox,
Tooltip,
IconButton,
Table,
TableBody,
TableCell,
TableHead,
TableRow,
Tooltip,
Box,
Checkbox,
Button,
} from "@mui/material";
import router from "next/router";
import { Job } from "../../types";
import StarIcon from "@mui/icons-material/Star";

interface stateProps {
selectedJobs: Set<string>;
@@ -21,12 +21,7 @@ interface stateProps {

interface Props {
onSelectJob: (job: string) => void;
onNavigate: (
id: string,
elements: Object[],
url: string,
options: any
) => void;
onNavigate: (elements: Object[], url: string, options: any) => void;
onFavorite: (ids: string[], field: string, value: any) => void;
stateProps: stateProps;
}
@@ -92,29 +87,11 @@ export const Favorites = ({
</TableCell>
<TableCell sx={{ maxWidth: 100, overflow: "auto" }}>
<Button
onClick={() => {
if (row.agent_mode) {
router.push({
pathname: "/agent",
query: {
url: row.url,
prompt: row.prompt,
job_options: JSON.stringify(row.job_options),
id: row.id,
},
});
} else {
onNavigate(row.id, row.elements, row.url, row.job_options);
}
}}
size="small"
sx={{
minWidth: 0,
padding: "4px 8px",
fontSize: "0.625rem",
}}
onClick={() =>
onNavigate(row.elements, row.url, row.job_options)
}
>
Rerun
Run
</Button>
</TableCell>
</TableRow>

@@ -1,11 +1,5 @@
"use client";
import { useExportJobConfig } from "@/hooks/use-export-job-config";
import {
AutoAwesome,
Image,
Settings,
VideoCameraBack,
} from "@mui/icons-material";
import { AutoAwesome, Image, VideoCameraBack } from "@mui/icons-material";
import StarIcon from "@mui/icons-material/Star";
import {
Box,
@@ -36,12 +30,7 @@ interface Props {
colors: stringMap;
onSelectJob: (job: string) => void;
onDownload: (job: string[]) => void;
onNavigate: (
id: string,
elements: Object[],
url: string,
options: any
) => void;
onNavigate: (elements: Object[], url: string, options: any) => void;
onFavorite: (ids: string[], field: string, value: any) => void;
onJobClick: (job: Job) => void;
stateProps: stateProps;
@@ -57,7 +46,6 @@ export const JobQueue = ({
onJobClick,
}: Props) => {
const { selectedJobs, filteredJobs } = stateProps;
const { exportJobConfig } = useExportJobConfig();
const router = useRouter();

return (
@@ -128,17 +116,6 @@ export const JobQueue = ({
</IconButton>
</span>
</Tooltip>
<Tooltip title="Export Job Configuration">
<span>
<IconButton
onClick={() => {
exportJobConfig(row);
}}
>
<Settings />
</IconButton>
</span>
</Tooltip>
{row.job_options.collect_media && (
<Tooltip title="View Media">
<span>
@@ -237,16 +214,10 @@ export const JobQueue = ({
url: row.url,
prompt: row.prompt,
job_options: JSON.stringify(row.job_options),
id: row.id,
},
});
} else {
onNavigate(
row.id,
row.elements,
row.url,
row.job_options
);
onNavigate(row.elements, row.url, row.job_options);
}
}}
size="small"

@@ -47,16 +47,10 @@ export const JobTable: React.FC<JobTableProps> = ({ jobs, setJobs }) => {
setJobDownloadDialogOpen(true);
};

const handleNavigate = (
id: string,
elements: Object[],
url: string,
options: any
) => {
const handleNavigate = (elements: Object[], url: string, options: any) => {
router.push({
pathname: "/",
query: {
id,
elements: JSON.stringify(elements),
url: url,
job_options: JSON.stringify(options),

@@ -13,44 +13,21 @@ import { useJobSubmitterProvider } from "./provider";
|
||||
|
||||
export const JobSubmitter = () => {
|
||||
const router = useRouter();
|
||||
const { job_options, id } = router.query;
|
||||
const { job_options } = router.query;
|
||||
const { user } = useUser();
|
||||
|
||||
const { submitJob, loading, error } = useSubmitJob();
|
||||
const {
|
||||
jobId,
|
||||
setJobId,
|
||||
submittedURL,
|
||||
rows,
|
||||
siteMap,
|
||||
setSiteMap,
|
||||
jobOptions,
|
||||
setJobOptions,
|
||||
} = useJobSubmitterProvider();
|
||||
const { submittedURL, rows, siteMap, setSiteMap, jobOptions, setJobOptions } =
|
||||
useJobSubmitterProvider();
|
||||
|
||||
useEffect(() => {
|
||||
if (job_options) {
|
||||
parseJobOptions(
|
||||
id as string,
|
||||
job_options as string,
|
||||
setJobOptions,
|
||||
setSiteMap,
|
||||
setJobId
|
||||
);
|
||||
parseJobOptions(job_options as string, setJobOptions, setSiteMap);
|
||||
}
|
||||
}, [job_options]);
|
||||
|
||||
const handleSubmit = async () => {
|
||||
await submitJob(
|
||||
submittedURL,
|
||||
rows,
|
||||
user,
|
||||
jobOptions,
|
||||
siteMap,
|
||||
false,
|
||||
null,
|
||||
jobId
|
||||
);
|
||||
await submitJob(submittedURL, rows, user, jobOptions, siteMap, false, null);
|
||||
};
|
||||
|
||||
return (
|
||||
|
||||
@@ -10,8 +10,6 @@ import React, {
} from "react";

type JobSubmitterProviderType = {
jobId: string;
setJobId: Dispatch<React.SetStateAction<string>>;
submittedURL: string;
setSubmittedURL: Dispatch<React.SetStateAction<string>>;
rows: Element[];
@@ -38,7 +36,6 @@ const JobSubmitterProvider = createContext<JobSubmitterProviderType>(
);

export const Provider = ({ children }: PropsWithChildren) => {
const [jobId, setJobId] = useState<string>("");
const [submittedURL, setSubmittedURL] = useState<string>("");
const [rows, setRows] = useState<Element[]>([]);
const [results, setResults] = useState<Result>({});
@@ -58,8 +55,6 @@ export const Provider = ({ children }: PropsWithChildren) => {

const value: JobSubmitterProviderType = useMemo(
() => ({
jobId,
setJobId,
submittedURL,
setSubmittedURL,
rows,
@@ -81,7 +76,6 @@ export const Provider = ({ children }: PropsWithChildren) => {
closeSnackbar,
}),
[
jobId,
submittedURL,
rows,
results,

@@ -15,14 +15,14 @@ export const useAdvancedJobOptions = () => {
};

const router = useRouter();
const { job_options, job_id } = router.query;
const { job_options } = router.query;

const [jobOptions, setJobOptions] =
useState<RawJobOptions>(initialJobOptions);

useEffect(() => {
if (job_options) {
parseJobOptions(job_id as string, job_options as string, setJobOptions);
parseJobOptions(job_options as string, setJobOptions);
}
}, [job_options]);


@@ -1,27 +0,0 @@
import { Job } from "@/types";

export const useExportJobConfig = () => {
const exportJobConfig = async (job: Job) => {
const jobConfig = {
url: job.url,
prompt: job.prompt,
job_options: job.job_options,
elements: job.elements,
agent_mode: job.agent_mode,
};

const jobConfigString = JSON.stringify(jobConfig);
const blob = new Blob([jobConfigString], { type: "application/json" });
const url = window.URL.createObjectURL(blob);
const a = document.createElement("a");
a.style.display = "none";
a.href = url;
a.download = `job_${job.id}.json`;
document.body.appendChild(a);
a.click();
window.URL.revokeObjectURL(url);
document.body.removeChild(a);
};

return { exportJobConfig };
};
@@ -1,83 +0,0 @@
import { useJobSubmitterProvider } from "@/components/submit/job-submitter/provider";
import { useRouter } from "next/router";
import { toast } from "react-toastify";

export const useImportJobConfig = () => {
const router = useRouter();
const { setJobOptions, setSiteMap, setSubmittedURL, setRows } =
useJobSubmitterProvider();

const handleUploadFile = (file: File): Promise<boolean> => {
return new Promise((resolve) => {
const reader = new FileReader();

reader.onerror = () => {
toast.error("Failed to read file");
resolve(true);
};

reader.onload = (e) => {
const result = e.target?.result as string;

if (!result.includes("url")) {
toast.error("Invalid job config: missing url");
resolve(true);
return;
}

if (!result.includes("job_options")) {
toast.error("Invalid job config: missing job_options");
resolve(true);
return;
}

if (!result.includes("elements")) {
toast.error("Invalid job config: missing elements");
resolve(true);
return;
}

if (!result.includes("site_map")) {
toast.error("Invalid job config: missing site_map");
resolve(true);
return;
}

try {
const jobConfig = JSON.parse(result);

if (jobConfig.agent_mode) {
router.push({
pathname: "/agent",
query: {
url: jobConfig.url,
prompt: jobConfig.prompt,
job_options: JSON.stringify(jobConfig.job_options),
},
});
}

if (
jobConfig.job_options &&
Array.isArray(jobConfig.job_options.proxies)
) {
jobConfig.job_options.proxies = "";
}

setJobOptions(jobConfig.job_options || {});
setSiteMap(jobConfig.site_map);
setSubmittedURL(jobConfig.url || "");
setRows(jobConfig.elements || []);
resolve(false);
} catch (error) {
toast.error("Failed to parse job config");
resolve(true);
}
};

reader.readAsText(file);
});
};

return { handleUploadFile };
};
@@ -25,8 +25,7 @@ export const useSubmitJob = () => {
jobOptions: RawJobOptions,
siteMap: SiteMap | null,
agentMode: boolean,
prompt: string | null,
id?: string
prompt: string | null
) => {
if (!validateURL(submittedURL)) {
setIsValidUrl(false);
@@ -62,8 +61,7 @@ export const useSubmitJob = () => {
customCookies,
siteMap,
agentMode,
prompt || undefined,
id
prompt || undefined
)
.then(async (response) => {
if (!response.ok) {
@@ -82,10 +80,7 @@ export const useSubmitJob = () => {
setSnackbarOpen(true);
})
.catch((error) => {
const errorMessage =
error instanceof Error ? error.message : "An error occurred.";
console.log(errorMessage);
setSnackbarMessage(errorMessage);
setSnackbarMessage(error || "An error occurred.");
setSnackbarSeverity("error");
setSnackbarOpen(true);
})

@@ -3,11 +3,9 @@ import { Dispatch, SetStateAction } from "react";
|
||||
import { RawJobOptions, SiteMap } from "@/types";
|
||||
|
||||
export const parseJobOptions = (
|
||||
id: string,
|
||||
job_options: string,
|
||||
setJobOptions: Dispatch<SetStateAction<RawJobOptions>>,
|
||||
setSiteMap?: Dispatch<SetStateAction<SiteMap | null>>,
|
||||
setJobId?: Dispatch<SetStateAction<string>>
|
||||
setSiteMap?: Dispatch<SetStateAction<SiteMap | null>>
|
||||
) => {
|
||||
if (job_options) {
|
||||
const jsonOptions = JSON.parse(job_options as string);
|
||||
@@ -49,10 +47,6 @@ export const parseJobOptions = (
|
||||
newJobOptions.return_html = true;
|
||||
}
|
||||
|
||||
if (id && setJobId) {
|
||||
setJobId(id);
|
||||
}
|
||||
|
||||
setJobOptions(newJobOptions);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -21,16 +21,15 @@ export default async function handler(
}
);

const result = await response.json();

if (response.status === 500) {
res.status(500).json({ error: result.error });
if (!response.ok) {
throw new Error(`Error: ${response.statusText}`);
}

const result = await response.json();
res.status(200).json(result);
} catch (error) {
console.error("Error submitting scrape job:", error);
res.status(500).json({ error: error });
res.status(500).json({ error: "Internal Server Error" });
}
} else {
res.setHeader("Allow", ["POST"]);

@@ -1,7 +1,6 @@
export const download = async (ids: string[], jobFormat: string) => {
const response = await fetch("/api/download", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ data: { ids, job_format: jobFormat } }),
});


@@ -9,15 +9,14 @@ export const submitJob = async (
customCookies: any,
siteMap: SiteMap | null,
agentMode: boolean = false,
prompt?: string,
id?: string
prompt?: string
) => {
console.log(user);
return await fetch(`/api/submit-scrape-job`, {
method: "POST",
headers: { "content-type": "application/json" },
body: JSON.stringify({
data: {
id,
url: submittedURL,
elements: rows,
user: user?.email,