mirror of
https://github.com/jaypyles/Scraperr.git
synced 2025-12-18 13:45:35 +00:00
Feat/swap to sqlalchemy (#99)
* chore: wip swap to sqlalchemy * feat: swap to sqlalchemy * feat: swap to sqlalchemy * feat: swap to sqlalchemy * feat: swap to sqlalchemy
This commit is contained in:
@@ -8,12 +8,15 @@ from datetime import datetime, timedelta
|
||||
from jose import JWTError, jwt
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from passlib.context import CryptContext
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
# LOCAL
|
||||
from api.backend.auth.schemas import User, UserInDB, TokenData
|
||||
from api.backend.database.common import query
|
||||
from api.backend.database.base import get_db
|
||||
from api.backend.database.models import User as UserModel
|
||||
|
||||
LOG = logging.getLogger("Auth")
|
||||
|
||||
@@ -37,18 +40,24 @@ def get_password_hash(password: str):
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
async def get_user(email: str):
|
||||
user_query = "SELECT * FROM users WHERE email = ?"
|
||||
user = query(user_query, (email,))[0]
|
||||
async def get_user(session: AsyncSession, email: str) -> UserInDB | None:
|
||||
stmt = select(UserModel).where(UserModel.email == email)
|
||||
result = await session.execute(stmt)
|
||||
user = result.scalars().first()
|
||||
|
||||
if not user:
|
||||
return
|
||||
return None
|
||||
|
||||
return UserInDB(**user)
|
||||
return UserInDB(
|
||||
email=str(user.email),
|
||||
hashed_password=str(user.hashed_password),
|
||||
full_name=str(user.full_name),
|
||||
disabled=bool(user.disabled),
|
||||
)
|
||||
|
||||
|
||||
async def authenticate_user(email: str, password: str):
|
||||
user = await get_user(email)
|
||||
async def authenticate_user(email: str, password: str, db: AsyncSession):
|
||||
user = await get_user(db, email)
|
||||
|
||||
if not user:
|
||||
return False
|
||||
@@ -74,7 +83,9 @@ def create_access_token(
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
async def get_current_user(token: str = Depends(oauth2_scheme)):
|
||||
async def get_current_user(
|
||||
db: AsyncSession = Depends(get_db), token: str = Depends(oauth2_scheme)
|
||||
):
|
||||
LOG.debug(f"Getting current user with token: {token}")
|
||||
|
||||
if not token:
|
||||
@@ -82,7 +93,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme)):
|
||||
return EMPTY_USER
|
||||
|
||||
if len(token.split(".")) != 3:
|
||||
LOG.error(f"Malformed token: {token}")
|
||||
LOG.debug(f"Malformed token: {token}")
|
||||
return EMPTY_USER
|
||||
|
||||
try:
|
||||
@@ -117,7 +128,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme)):
|
||||
LOG.error(f"Exception occurred: {e}")
|
||||
return EMPTY_USER
|
||||
|
||||
user = await get_user(email=token_data.email or "")
|
||||
user = await get_user(db, email=token_data.email or "")
|
||||
|
||||
if user is None:
|
||||
return EMPTY_USER
|
||||
@@ -125,7 +136,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme)):
|
||||
return user
|
||||
|
||||
|
||||
async def require_user(token: str = Depends(oauth2_scheme)):
|
||||
async def require_user(db: AsyncSession, token: str = Depends(oauth2_scheme)):
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
@@ -150,7 +161,7 @@ async def require_user(token: str = Depends(oauth2_scheme)):
|
||||
except JWTError:
|
||||
raise credentials_exception
|
||||
|
||||
user = await get_user(email=token_data.email or "")
|
||||
user = await get_user(db, email=token_data.email or "")
|
||||
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
|
||||
Reference in New Issue
Block a user