Files
picrinth-server/src/api/utils.py

114 lines
3.1 KiB
Python

from datetime import datetime, timedelta, timezone
from typing import Annotated
import bcrypt
import jwt
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jwt.exceptions import InvalidTokenError
# from passlib.context import CryptContext
from psycopg2._psycopg import connection
import db.groups
import db.users
import settings.startup_settings as startup_settings
from api.models import Group, TokenData, User
from db.internal import get_db_connection
# Bearer-token extractor used by get_current_user via Depends(); "token" is
# the (relative) token endpoint URL that FastAPI advertises to OAuth2 clients.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True if *plain_password* matches the stored bcrypt *hashed_password*.

    Both strings are UTF-8 encoded before being handed to ``bcrypt.checkpw``.
    Any ``ValueError`` is treated as a mismatch (fail closed) instead of being
    propagated. This covers a stored hash bcrypt cannot parse ("Invalid salt")
    and a password that cannot be UTF-8 encoded (``UnicodeEncodeError``).
    Propagating either error would surface as a server error during login.
    """
    try:
        return bcrypt.checkpw(
            plain_password.encode("utf-8"),
            hashed_password.encode("utf-8"),
        )
    except ValueError:
        # Malformed/unsupported hash or unencodable input: cannot match.
        return False
def get_password_hash(password: str):
    """Hash *password* with bcrypt under a freshly generated salt.

    The password is UTF-8 encoded before hashing, and the resulting bcrypt
    hash bytes are decoded back to ``str`` for storage.
    """
    raw_password = password.encode("utf-8")
    salt = bcrypt.gensalt()
    hashed = bcrypt.hashpw(raw_password, salt)
    return hashed.decode("utf-8")
def decode_token(token):
    """Verify *token* with the configured secret and return its claims.

    Only the single algorithm from ``startup_settings.algorithm`` is accepted.
    Validation failures surface as PyJWT exceptions; callers such as
    get_current_user catch ``InvalidTokenError``.
    """
    secret = startup_settings.secret_key
    accepted_algorithms = [startup_settings.algorithm]
    return jwt.decode(token, secret, algorithms=accepted_algorithms)
def encode_token(payload):
    """Sign *payload* as a JWT using the configured secret and algorithm."""
    secret = startup_settings.secret_key
    signing_algorithm = startup_settings.algorithm
    return jwt.encode(payload, secret, algorithm=signing_algorithm)
def authenticate_user(
    conn: connection,
    username: str,
    user_password: str
):
    """Check a username/password pair against the stored password hash.

    Returns False for an empty or missing password, for a user with no
    stored hash, or for a password that does not verify. Returns True only
    when ``verify_password`` accepts the password.
    """
    # Blank passwords are rejected without querying the database.
    if not user_password:
        return False
    stored_hash = db.users.get_user_password(conn, username)
    return stored_hash is not None and bool(
        verify_password(user_password, stored_hash)
    )
def create_access_token(
    data: dict,
    expires_delta: timedelta
):
    """Return a signed JWT containing *data* plus an ``exp`` claim.

    The expiry is the current UTC time plus *expires_delta*. An ``exp`` key
    already present in *data* is overwritten. The caller's dict is not
    mutated, because a shallow copy is signed instead.
    """
    claims = data.copy()
    claims["exp"] = datetime.now(timezone.utc) + expires_delta
    return encode_token(claims)
async def get_current_user(
    token: Annotated[str, Depends(oauth2_scheme)],
    conn: Annotated[connection, Depends(get_db_connection)]
):
    """Resolve the request's bearer token to an enabled ``User``.

    Intended as a FastAPI dependency: ``token`` comes from ``oauth2_scheme``
    and ``conn`` from ``get_db_connection``.

    Raises:
        HTTPException(401): the token fails verification, lacks a string
            ``sub`` claim, or names a user that does not exist.
        HTTPException(403): the user exists but is disabled.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"}
    )
    try:
        payload = decode_token(token)
    except InvalidTokenError as err:
        # Chain the JWT error so server-side tracebacks show why it failed.
        raise credentials_exception from err
    username = payload.get("sub")
    # RFC 7519 defines "sub" as a string. A missing or non-string value is
    # invalid credentials, not something to pass on to TokenData / the DB.
    if not isinstance(username, str):
        raise credentials_exception
    token_data = TokenData(username=username)
    user = User()
    # NOTE(review): get_user takes a psycopg2 connection, so this is presumably
    # a blocking query run on the event loop inside an async dependency.
    # Confirm, and consider a sync dependency or run_in_threadpool.
    user_data = db.users.get_user(conn, token_data.username)
    if user_data is None:
        raise credentials_exception
    user.fill(user_data)
    if user.disabled:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="User is disabled"
        )
    return user
def get_group_by_name(
    conn: connection,
    groupname: str
):
    """Load the group named *groupname*, or fail with HTTP 404.

    Returns a ``Group`` populated via ``Group.fill`` from the row returned by
    ``db.groups.get_group``. Raises ``HTTPException`` (404, "No such group")
    when that lookup returns None.
    """
    group = Group()
    row = db.groups.get_group(conn, groupname)
    if row is not None:
        group.fill(row)
        return group
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="No such group"
    )