fixed auth, added more endpoints and config saving
This commit is contained in:
40
src/api/anon.py
Normal file
40
src/api/anon.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from psycopg2._psycopg import connection
|
||||||
|
|
||||||
|
import db.users as db
|
||||||
|
import settings.settings as settings
|
||||||
|
from api.utils import get_password_hash
|
||||||
|
from db.internal import get_db_connection
|
||||||
|
|
||||||
|
anon_router = APIRouter(prefix="/api/anon", tags=["anon"])
|
||||||
|
|
||||||
|
|
||||||
|
@anon_router.post("/add/admin")
|
||||||
|
async def add_admin(
|
||||||
|
username: str,
|
||||||
|
password: str,
|
||||||
|
conn: Annotated[connection, Depends(get_db_connection)]
|
||||||
|
):
|
||||||
|
if not settings.settings.allow_create_admins:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Not allowed",
|
||||||
|
)
|
||||||
|
hashed_password = get_password_hash(password)
|
||||||
|
return db.create_user(conn, username, hashed_password, "admin")
|
||||||
|
|
||||||
|
@anon_router.post("/add/user")
|
||||||
|
async def add_user(
|
||||||
|
username: str,
|
||||||
|
password: str,
|
||||||
|
conn: Annotated[connection, Depends(get_db_connection)]
|
||||||
|
):
|
||||||
|
if not settings.settings.allow_create_users:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Not allowed",
|
||||||
|
)
|
||||||
|
hashed_password = get_password_hash(password)
|
||||||
|
return db.create_user(conn, username, hashed_password, "user")
|
||||||
@ -8,7 +8,7 @@ from psycopg2._psycopg import connection
|
|||||||
from api.models import Token
|
from api.models import Token
|
||||||
from api.utils import authenticate_user, create_access_token
|
from api.utils import authenticate_user, create_access_token
|
||||||
from db.internal import get_db_connection
|
from db.internal import get_db_connection
|
||||||
from settings import settings
|
from settings import startup_settings
|
||||||
|
|
||||||
auth_router = APIRouter(prefix="/api", tags=["auth"])
|
auth_router = APIRouter(prefix="/api", tags=["auth"])
|
||||||
|
|
||||||
@ -18,16 +18,16 @@ async def login(
|
|||||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||||
conn: Annotated[connection, Depends(get_db_connection)]
|
conn: Annotated[connection, Depends(get_db_connection)]
|
||||||
) -> Token:
|
) -> Token:
|
||||||
user = authenticate_user(conn, form_data.username, form_data.password) # change db
|
password_correct = authenticate_user(conn, form_data.username, form_data.password)
|
||||||
if not user:
|
if not password_correct:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Incorrect username or password",
|
detail="Incorrect username or password",
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
)
|
)
|
||||||
|
|
||||||
access_token_expire_time = timedelta(minutes=settings.access_token_expiration_time)
|
access_token_expire_time = timedelta(minutes=startup_settings.access_token_expiration_time)
|
||||||
access_token = create_access_token(
|
access_token = create_access_token(
|
||||||
data={"sub": user.username}, expires_delta=access_token_expire_time
|
data={"sub": form_data.username}, expires_delta=access_token_expire_time
|
||||||
)
|
)
|
||||||
return Token(access_token=access_token, token_type="bearer")
|
return Token(access_token=access_token, token_type="bearer")
|
||||||
|
|||||||
64
src/api/general.py
Normal file
64
src/api/general.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
|
||||||
|
import settings.settings as settings
|
||||||
|
from api.models import User
|
||||||
|
from api.utils import get_current_user
|
||||||
|
from settings.settings import load_settings, reset_settings, save_settings
|
||||||
|
|
||||||
|
general_router = APIRouter(prefix="/api", tags=["status"])
|
||||||
|
|
||||||
|
|
||||||
|
@general_router.get('/ping')
|
||||||
|
async def ping():
|
||||||
|
return {'ok'}
|
||||||
|
|
||||||
|
|
||||||
|
@general_router.get('/settings/get')
|
||||||
|
async def get_settings():
|
||||||
|
return settings.settings
|
||||||
|
|
||||||
|
|
||||||
|
@general_router.post('/settings/update')
|
||||||
|
async def update_settings(data: dict, current_user: Annotated[User, Depends(get_current_user)]):
|
||||||
|
if current_user.role not in settings.settings.admin_roles:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Not allowed",
|
||||||
|
)
|
||||||
|
for key, value in data.items():
|
||||||
|
setattr(settings.settings, key, value)
|
||||||
|
return settings.settings
|
||||||
|
|
||||||
|
|
||||||
|
@general_router.get('/settings/reset')
|
||||||
|
async def reset_settings_api(current_user: Annotated[User, Depends(get_current_user)]):
|
||||||
|
if current_user.role not in settings.settings.admin_roles:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Not allowed",
|
||||||
|
)
|
||||||
|
reset_settings()
|
||||||
|
return settings.settings
|
||||||
|
|
||||||
|
|
||||||
|
@general_router.get('/settings/load_from_file')
|
||||||
|
async def load_settings_api(current_user: Annotated[User, Depends(get_current_user)]):
|
||||||
|
if current_user.role not in settings.settings.admin_roles:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Not allowed",
|
||||||
|
)
|
||||||
|
load_settings()
|
||||||
|
return settings.settings
|
||||||
|
|
||||||
|
@general_router.get('/settings/save_to_file')
|
||||||
|
async def save_settings_api(current_user: Annotated[User, Depends(get_current_user)]):
|
||||||
|
if current_user.role not in settings.settings.admin_roles:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Not allowed",
|
||||||
|
)
|
||||||
|
save_settings()
|
||||||
|
return settings.settings
|
||||||
@ -16,12 +16,14 @@ class User(BaseModel):
|
|||||||
def fill(self, params):
|
def fill(self, params):
|
||||||
self.username = params['username']
|
self.username = params['username']
|
||||||
self.password = params['password']
|
self.password = params['password']
|
||||||
|
self.role = params['role']
|
||||||
self.disabled = params['disabled']
|
self.disabled = params['disabled']
|
||||||
self.groups_ids = params['groups_ids']
|
self.groups_ids = params['groups_ids']
|
||||||
self.last_seen_at = params['last_seen_at']
|
self.last_seen_at = params['last_seen_at']
|
||||||
self.created_at = params['created_at']
|
self.created_at = params['created_at']
|
||||||
username: str = ''
|
username: str = ''
|
||||||
password: str = ''
|
password: str = ''
|
||||||
|
role: str = 'user'
|
||||||
disabled: bool = False
|
disabled: bool = False
|
||||||
groups_ids: list[str] | None = None
|
groups_ids: list[str] | None = None
|
||||||
last_seen_at: datetime | None = None
|
last_seen_at: datetime | None = None
|
||||||
|
|||||||
@ -1,8 +0,0 @@
|
|||||||
from fastapi import APIRouter
|
|
||||||
|
|
||||||
status_router = APIRouter(prefix="/api", tags=["status"])
|
|
||||||
|
|
||||||
|
|
||||||
@status_router.get('/ping')
|
|
||||||
async def ping():
|
|
||||||
return {'ok'}
|
|
||||||
@ -1,21 +0,0 @@
|
|||||||
from typing import Annotated
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
|
||||||
|
|
||||||
from api.models import User
|
|
||||||
from api.utils import get_current_user
|
|
||||||
|
|
||||||
test_router = APIRouter(prefix="/api", tags=["test"])
|
|
||||||
|
|
||||||
|
|
||||||
@test_router.get('/test-private')
|
|
||||||
async def test_private_func(token: Annotated[User, Depends(get_current_user)]):
|
|
||||||
return {'private nya'}
|
|
||||||
|
|
||||||
|
|
||||||
@test_router.post('/test')
|
|
||||||
async def test_func(text: str):
|
|
||||||
print(text)
|
|
||||||
if text == 'thighs':
|
|
||||||
raise HTTPException(status_code=status.HTTP_402_PAYMENT_REQUIRED)
|
|
||||||
return {'nya'}
|
|
||||||
128
src/api/users.py
128
src/api/users.py
@ -1,11 +1,12 @@
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from psycopg2._psycopg import connection
|
from psycopg2._psycopg import connection
|
||||||
|
|
||||||
import db.users as db
|
import db.users as db
|
||||||
|
import settings.settings as settings
|
||||||
from api.models import User
|
from api.models import User
|
||||||
from api.utils import get_current_user
|
from api.utils import get_current_user, get_password_hash
|
||||||
from db.internal import get_db_connection
|
from db.internal import get_db_connection
|
||||||
|
|
||||||
users_router = APIRouter(prefix="/api/users", tags=["users"])
|
users_router = APIRouter(prefix="/api/users", tags=["users"])
|
||||||
@ -15,19 +16,126 @@ users_router = APIRouter(prefix="/api/users", tags=["users"])
|
|||||||
async def read_users_me(current_user: Annotated[User, Depends(get_current_user)]):
|
async def read_users_me(current_user: Annotated[User, Depends(get_current_user)]):
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
@users_router.post("/user")
|
@users_router.post("/user")
|
||||||
async def read_users_any(username: str, conn: Annotated[connection, Depends(get_db_connection)]):
|
async def read_users_any(
|
||||||
|
username: str,
|
||||||
|
conn: Annotated[connection, Depends(get_db_connection)],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)]
|
||||||
|
):
|
||||||
|
if current_user.role not in settings.settings.admin_roles:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Not allowed",
|
||||||
|
)
|
||||||
user = User()
|
user = User()
|
||||||
user.fill(db.get_user(conn, username))
|
user_data = db.get_user(conn, username)
|
||||||
|
if user_data is None:
|
||||||
|
return HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="No such user",
|
||||||
|
)
|
||||||
|
user.fill(user_data)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
@users_router.post("/add")
|
@users_router.post("/add/admin")
|
||||||
async def add_user(username: str, password: str, conn: Annotated[connection, Depends(get_db_connection)]):
|
async def add_admin(
|
||||||
return db.create_user(conn, username, password)
|
username: str,
|
||||||
|
password: str,
|
||||||
|
conn: Annotated[connection, Depends(get_db_connection)],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)]
|
||||||
|
):
|
||||||
|
if not settings.settings.allow_create_admins_by_admins:
|
||||||
|
if current_user.role not in settings.settings.admin_roles:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Not allowed",
|
||||||
|
)
|
||||||
|
hashed_password = get_password_hash(password)
|
||||||
|
return db.create_user(conn, username, hashed_password, "admin")
|
||||||
|
|
||||||
|
@users_router.post("/add/user")
|
||||||
|
async def add_user(
|
||||||
|
username: str,
|
||||||
|
password: str,
|
||||||
|
conn: Annotated[connection, Depends(get_db_connection)],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)]
|
||||||
|
):
|
||||||
|
if not settings.settings.allow_create_users or current_user.role not in settings.settings.admin_roles:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Not allowed",
|
||||||
|
)
|
||||||
|
hashed_password = get_password_hash(password)
|
||||||
|
return db.create_user(conn, username, hashed_password, "user")
|
||||||
|
|
||||||
|
|
||||||
@users_router.post("/delete")
|
@users_router.post("/delete")
|
||||||
async def delete_user(username: str, conn: Annotated[connection, Depends(get_db_connection)]):
|
async def delete_user(
|
||||||
return db.delete_user(conn, username)
|
username: str,
|
||||||
|
conn: Annotated[connection, Depends(get_db_connection)],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)]
|
||||||
|
):
|
||||||
|
if current_user.username == username or current_user.role in settings.settings.admin_roles:
|
||||||
|
return db.delete_user(conn, username)
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Not allowed",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@users_router.post("/update/disabled")
|
||||||
|
async def update_disabled(
|
||||||
|
username: str,
|
||||||
|
disabled: bool,
|
||||||
|
conn: Annotated[connection, Depends(get_db_connection)],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)]
|
||||||
|
):
|
||||||
|
if current_user.role in settings.settings.admin_roles:
|
||||||
|
return db.update_user_disabled(conn, username, disabled)
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Not allowed",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@users_router.post("/update/username")
|
||||||
|
async def update_username(
|
||||||
|
username: str,
|
||||||
|
password: str,
|
||||||
|
conn: Annotated[connection, Depends(get_db_connection)],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)]
|
||||||
|
):
|
||||||
|
if current_user.username == username or current_user.role in settings.settings.admin_roles:
|
||||||
|
return db.update_user_username(conn, username, password)
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Not allowed",
|
||||||
|
)
|
||||||
|
|
||||||
|
@users_router.post("/update/password")
|
||||||
|
async def update_password(
|
||||||
|
username: str,
|
||||||
|
password: str,
|
||||||
|
conn: Annotated[connection, Depends(get_db_connection)],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)]
|
||||||
|
):
|
||||||
|
if current_user.username == username or current_user.role in settings.settings.admin_roles:
|
||||||
|
hashed_password = get_password_hash(password)
|
||||||
|
return db.update_user_password(conn, username, hashed_password)
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Not allowed",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@users_router.get("/update/last_seen")
|
||||||
|
async def update_last_seen(
|
||||||
|
conn: Annotated[connection, Depends(get_db_connection)],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)]
|
||||||
|
):
|
||||||
|
return db.update_user_last_seen(conn, current_user.username)
|
||||||
|
|||||||
@ -1,47 +1,50 @@
|
|||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
|
import bcrypt
|
||||||
import jwt
|
import jwt
|
||||||
from fastapi import Depends, HTTPException, status
|
from fastapi import Depends, HTTPException, status
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
from jwt.exceptions import InvalidTokenError
|
from jwt.exceptions import InvalidTokenError
|
||||||
from passlib.context import CryptContext
|
|
||||||
|
# from passlib.context import CryptContext
|
||||||
from psycopg2._psycopg import connection
|
from psycopg2._psycopg import connection
|
||||||
|
|
||||||
import db.users
|
import db.users
|
||||||
import settings.settings as settings
|
import settings.startup_settings as startup_settings
|
||||||
from api.models import TokenData, User
|
from api.models import TokenData, User
|
||||||
from db.internal import get_db_connection
|
from db.internal import get_db_connection
|
||||||
|
|
||||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
# pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||||
|
|
||||||
|
|
||||||
|
def verify_password(plain_password: str, hashed_password: str):
|
||||||
|
return bcrypt.checkpw(plain_password.encode("utf-8"), hashed_password.encode("utf-8"))
|
||||||
|
|
||||||
|
def get_password_hash(password: str):
|
||||||
|
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
|
||||||
|
|
||||||
def decode_token(token):
|
def decode_token(token):
|
||||||
return jwt.decode(token, settings.secret_key, algorithms=[settings.algorithm])
|
return jwt.decode(token, startup_settings.secret_key, algorithms=[startup_settings.algorithm])
|
||||||
|
|
||||||
def encode_token(payload):
|
def encode_token(payload):
|
||||||
return jwt.encode(payload, settings.secret_key, algorithm=settings.algorithm)
|
return jwt.encode(payload, startup_settings.secret_key, algorithm=startup_settings.algorithm)
|
||||||
|
|
||||||
def verify_password(plain_password, hashed_password):
|
|
||||||
return pwd_context.verify(plain_password, hashed_password)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def authenticate_user(
|
def authenticate_user(
|
||||||
conn: connection,
|
conn: connection,
|
||||||
username: str,
|
username: str,
|
||||||
password: str
|
user_password: str
|
||||||
):
|
):
|
||||||
user = User()
|
db_user_password = db.users.get_user_password(conn, username)
|
||||||
userdata = db.users.get_user(conn, username)
|
if not user_password:
|
||||||
if not userdata:
|
|
||||||
return False
|
return False
|
||||||
if not verify_password(password, user.password):
|
if not verify_password(user_password, db_user_password):
|
||||||
return False
|
return False
|
||||||
user.fill(userdata)
|
return True
|
||||||
return user
|
|
||||||
|
|
||||||
def create_access_token(
|
def create_access_token(
|
||||||
data: dict,
|
data: dict,
|
||||||
@ -66,7 +69,6 @@ async def get_current_user(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
payload = decode_token(token)
|
payload = decode_token(token)
|
||||||
print(payload)
|
|
||||||
username = payload.get("sub")
|
username = payload.get("sub")
|
||||||
if username is None:
|
if username is None:
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
@ -75,14 +77,16 @@ async def get_current_user(
|
|||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
|
|
||||||
user = User()
|
user = User()
|
||||||
user.fill(db.users.get_user(conn, username=token_data.username))
|
user_data = db.users.get_user(conn, token_data.username)
|
||||||
if user is None:
|
if user_data is None:
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
|
|
||||||
|
user.fill(user_data)
|
||||||
|
|
||||||
if user.disabled:
|
if user.disabled:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail="Inactive user"
|
detail="User is disabled"
|
||||||
)
|
)
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|||||||
@ -1,14 +1,15 @@
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
from api.anon import anon_router
|
||||||
from api.auth import auth_router
|
from api.auth import auth_router
|
||||||
from api.status import status_router
|
from api.general import general_router
|
||||||
from api.tests import test_router
|
|
||||||
from api.users import users_router
|
from api.users import users_router
|
||||||
from db.internal import connect_db, disconnect_db
|
from db.internal import connect_db, disconnect_db
|
||||||
from settings import settings
|
from settings import startup_settings
|
||||||
|
from settings.settings import settings_down, settings_up
|
||||||
|
|
||||||
docs_url = None
|
docs_url = None
|
||||||
if settings.swagger_enabled:
|
if startup_settings.swagger_enabled:
|
||||||
docs_url = "/api/docs"
|
docs_url = "/api/docs"
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
@ -19,12 +20,14 @@ app = FastAPI(
|
|||||||
|
|
||||||
def create_app():
|
def create_app():
|
||||||
app.add_event_handler("startup", connect_db)
|
app.add_event_handler("startup", connect_db)
|
||||||
|
app.add_event_handler("startup", settings_up)
|
||||||
|
|
||||||
app.include_router(status_router)
|
app.include_router(general_router)
|
||||||
app.include_router(auth_router)
|
app.include_router(auth_router)
|
||||||
app.include_router(users_router)
|
app.include_router(users_router)
|
||||||
app.include_router(test_router)
|
app.include_router(anon_router)
|
||||||
|
|
||||||
app.add_event_handler("shutdown", disconnect_db)
|
app.add_event_handler("shutdown", disconnect_db)
|
||||||
|
app.add_event_handler("shutdown", settings_down)
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|||||||
@ -4,18 +4,18 @@ import psycopg2
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from db.models import database
|
from db.models import database
|
||||||
from settings import settings
|
from settings import startup_settings
|
||||||
|
|
||||||
|
|
||||||
def connect_db():
|
def connect_db():
|
||||||
logger.info("Initializing DB connection")
|
logger.info("Initializing DB connection")
|
||||||
try:
|
try:
|
||||||
database.conn = psycopg2.connect(
|
database.conn = psycopg2.connect(
|
||||||
dbname=settings.db_name,
|
dbname=startup_settings.db_name,
|
||||||
user=settings.db_user,
|
user=startup_settings.db_user,
|
||||||
password=settings.db_password,
|
password=startup_settings.db_password,
|
||||||
host=settings.db_host,
|
host=startup_settings.db_host,
|
||||||
port=settings.db_port,
|
port=startup_settings.db_port,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to initialize DB connection: {e}")
|
logger.error(f"Failed to initialize DB connection: {e}")
|
||||||
|
|||||||
@ -6,16 +6,17 @@ from psycopg2._psycopg import connection
|
|||||||
def create_user(
|
def create_user(
|
||||||
conn: connection,
|
conn: connection,
|
||||||
username: str,
|
username: str,
|
||||||
password: str
|
password: str,
|
||||||
|
role: str = "user"
|
||||||
):
|
):
|
||||||
with conn.cursor() as cur:
|
with conn.cursor() as cur:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"""
|
"""
|
||||||
insert into picrinth.users
|
insert into picrinth.users
|
||||||
(username, password, disabled, created_at)
|
(username, password, role, disabled, created_at)
|
||||||
values (%s, %s, false, now())
|
values (%s, %s, %s, false, now())
|
||||||
""",
|
""",
|
||||||
(username, password),
|
(username, password, role),
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
return cur.rowcount > 0
|
return cur.rowcount > 0
|
||||||
@ -74,6 +75,22 @@ def check_user_disabled(
|
|||||||
|
|
||||||
# user updates
|
# user updates
|
||||||
|
|
||||||
|
def update_user_username(
|
||||||
|
conn: connection,
|
||||||
|
username: str,
|
||||||
|
newUsername: str
|
||||||
|
):
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
update picrinth.users
|
||||||
|
set username = %s
|
||||||
|
where username = %s;
|
||||||
|
""",
|
||||||
|
(newUsername, username),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
def update_user_password(
|
def update_user_password(
|
||||||
conn: connection,
|
conn: connection,
|
||||||
username: str,
|
username: str,
|
||||||
@ -91,22 +108,24 @@ def update_user_password(
|
|||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
def update_user_username(
|
def update_user_disabled(
|
||||||
conn: connection,
|
conn: connection,
|
||||||
username: str,
|
username: str,
|
||||||
newUsername: str
|
disabled: bool
|
||||||
):
|
):
|
||||||
|
# if disabled = True -> user is disabled
|
||||||
with conn.cursor() as cur:
|
with conn.cursor() as cur:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"""
|
"""
|
||||||
update picrinth.users
|
update picrinth.users
|
||||||
set username = %s
|
set disabled = %s
|
||||||
where username = %s;
|
where username = %s
|
||||||
""",
|
""",
|
||||||
(newUsername, username),
|
(disabled, username),
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
def update_user_last_seen(
|
def update_user_last_seen(
|
||||||
conn: connection,
|
conn: connection,
|
||||||
username: str
|
username: str
|
||||||
@ -132,8 +151,9 @@ def get_user(
|
|||||||
with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"""
|
"""
|
||||||
select username, password, disabled,
|
select username, password, role,
|
||||||
groups_ids, last_seen_at, created_at
|
disabled, groups_ids,
|
||||||
|
last_seen_at, created_at
|
||||||
from picrinth.users
|
from picrinth.users
|
||||||
where username = %s
|
where username = %s
|
||||||
""",
|
""",
|
||||||
@ -154,4 +174,4 @@ def get_user_password(
|
|||||||
""",
|
""",
|
||||||
(username,),
|
(username,),
|
||||||
)
|
)
|
||||||
return cur.fetchone()
|
return cur.fetchone()[0] # type: ignore
|
||||||
|
|||||||
@ -1,22 +1,80 @@
|
|||||||
from decouple import config
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
def str_to_bool(string: str) -> bool:
|
class Settings(BaseModel):
|
||||||
if string.lower() == 'true':
|
def update(self, params):
|
||||||
return True
|
self.admin_roles = params['admin_roles']
|
||||||
return False
|
self.allow_create_admins_by_admins = params['allow_create_admins_by_admins']
|
||||||
|
self.allow_create_admins = params['allow_create_admins']
|
||||||
|
self.allow_create_users = params['allow_create_users']
|
||||||
|
admin_roles: list[str] = ['admin']
|
||||||
|
allow_create_admins_by_admins: bool = True
|
||||||
|
allow_create_admins: bool = True
|
||||||
|
allow_create_users: bool = True
|
||||||
|
|
||||||
# database
|
|
||||||
db_host = str(config('db_host', default='127.0.0.1'))
|
|
||||||
db_port = int(config('db_port', default=5432))
|
|
||||||
db_name = str(config('db_name', default='postgres'))
|
|
||||||
db_user = str(config('db_user', default='postgres'))
|
|
||||||
db_password = str(config('db_password', default='postgres'))
|
|
||||||
|
|
||||||
# auth
|
json_path = 'data/'
|
||||||
secret_key = str(config('secret_key'))
|
json_settings_name = 'settings.json'
|
||||||
algorithm = str(config('algorithm', 'HS256'))
|
|
||||||
access_token_expiration_time = int(config('access_token_expiration_time', default=10080))
|
|
||||||
|
|
||||||
# other settings
|
settings = Settings()
|
||||||
swagger_enabled = str_to_bool(str(config('swagger_enabled', 'false')))
|
|
||||||
|
|
||||||
|
def settings_up():
|
||||||
|
global settings, json_path, json_settings_name
|
||||||
|
logger.info('Configuring settings for startup')
|
||||||
|
try:
|
||||||
|
if not(os.path.exists(json_path)):
|
||||||
|
os.mkdir(json_path)
|
||||||
|
logger.debug(f'Created "{json_path}" directory')
|
||||||
|
|
||||||
|
if os.path.exists(json_path + json_settings_name):
|
||||||
|
load_settings()
|
||||||
|
else:
|
||||||
|
with open(json_path + json_settings_name, 'w') as f:
|
||||||
|
json.dump(settings.model_dump_json(), f, ensure_ascii = False, indent=4)
|
||||||
|
logger.info('Wrote settings to the JSON')
|
||||||
|
logger.info('Successfully configured settings')
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f'Failed to configure settings during startup: {e}')
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def settings_down():
|
||||||
|
global settings, json_path, json_settings_name
|
||||||
|
logger.info('Saving settings for shutdown')
|
||||||
|
try:
|
||||||
|
with open(json_path + json_settings_name, 'w') as f:
|
||||||
|
json.dump(settings.model_dump_json(), f, ensure_ascii = False, indent=4)
|
||||||
|
logger.info('Wrote settings to the JSON')
|
||||||
|
logger.success('Successfully saved settings')
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f'Failed to save settings during shutdown: {e}')
|
||||||
|
|
||||||
|
|
||||||
|
def reset_settings():
|
||||||
|
global settings, json_path, json_settings_name
|
||||||
|
logger.info('Resetting settings')
|
||||||
|
print(settings)
|
||||||
|
settings = Settings()
|
||||||
|
print(settings)
|
||||||
|
with open(json_path + json_settings_name, 'w') as f:
|
||||||
|
json.dump(settings.model_dump_json(), f, ensure_ascii = False, indent=4)
|
||||||
|
logger.info('Wrote settings to the JSON')
|
||||||
|
|
||||||
|
|
||||||
|
def load_settings():
|
||||||
|
global settings, json_path, json_settings_name
|
||||||
|
logger.info('Loading settings')
|
||||||
|
with open(json_path + json_settings_name, 'r') as f:
|
||||||
|
json_settings = json.load(f)
|
||||||
|
settings = Settings.model_validate_json(json_settings)
|
||||||
|
logger.info('Loaded settings from the JSON')
|
||||||
|
|
||||||
|
def save_settings():
|
||||||
|
global settings, json_path, json_settings_name
|
||||||
|
with open(json_path + json_settings_name, 'w') as f:
|
||||||
|
json.dump(settings.model_dump_json(), f, ensure_ascii = False, indent=4)
|
||||||
|
logger.info('Wrote settings to the JSON')
|
||||||
|
|||||||
22
src/settings/startup_settings.py
Normal file
22
src/settings/startup_settings.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
from decouple import config
|
||||||
|
|
||||||
|
|
||||||
|
def str_to_bool(string: str) -> bool:
|
||||||
|
if string.lower() == 'true':
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
# database
|
||||||
|
db_host = str(config('db_host', default='127.0.0.1'))
|
||||||
|
db_port = int(config('db_port', default=5432))
|
||||||
|
db_name = str(config('db_name', default='postgres'))
|
||||||
|
db_user = str(config('db_user', default='postgres'))
|
||||||
|
db_password = str(config('db_password', default='postgres'))
|
||||||
|
|
||||||
|
# auth
|
||||||
|
secret_key = str(config('secret_key'))
|
||||||
|
algorithm = str(config('algorithm', 'HS256'))
|
||||||
|
access_token_expiration_time = int(config('access_token_expiration_time', default=10080))
|
||||||
|
|
||||||
|
# other settings
|
||||||
|
swagger_enabled = str_to_bool(str(config('swagger_enabled', 'false')))
|
||||||
22
tables.sql
22
tables.sql
@ -1,7 +1,9 @@
|
|||||||
CREATE TABLE public.users (
|
CREATE TABLE picrinth.users (
|
||||||
id serial NOT NULL,
|
id serial not null,
|
||||||
username text NOT NULL,
|
username text not null,
|
||||||
"password" text NOT NULL,
|
"password" text not null,
|
||||||
|
"role" text not null default "user",
|
||||||
|
"disabled" bool not null,
|
||||||
groups_ids integer[] NULL,
|
groups_ids integer[] NULL,
|
||||||
last_seen_at timestamp with time zone NULL,
|
last_seen_at timestamp with time zone NULL,
|
||||||
created_at timestamp with time zone NULL,
|
created_at timestamp with time zone NULL,
|
||||||
@ -9,11 +11,11 @@ CREATE TABLE public.users (
|
|||||||
CONSTRAINT username_unique UNIQUE (username)
|
CONSTRAINT username_unique UNIQUE (username)
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE public.groups (
|
CREATE TABLE picrinth.groups (
|
||||||
id serial NOT NULL,
|
id serial not null,
|
||||||
groupname text NOT NULL,
|
groupname text not null,
|
||||||
join_code text NOT NULL,
|
join_code text not null,
|
||||||
users_ids integer[] NULL,
|
users_ids integer[] null,
|
||||||
created_at timestamp with time zone NULL,
|
created_at timestamp with time zone null,
|
||||||
CONSTRAINT groupname_unique UNIQUE (username)
|
CONSTRAINT groupname_unique UNIQUE (username)
|
||||||
);
|
);
|
||||||
|
|||||||
Reference in New Issue
Block a user