users table endpoints. auth to fix
This commit is contained in:
7
src/__main__.py
Normal file
7
src/__main__.py
Normal file
@ -0,0 +1,7 @@
|
||||
import uvicorn
|
||||
|
||||
from create_app import create_app
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = create_app()
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
33
src/api/auth.py
Normal file
33
src/api/auth.py
Normal file
@ -0,0 +1,33 @@
|
||||
from datetime import timedelta
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from psycopg2._psycopg import connection
|
||||
|
||||
from api.models import Token
|
||||
from api.utils import authenticate_user, create_access_token
|
||||
from db.internal import get_db_connection
|
||||
from settings import settings
|
||||
|
||||
auth_router = APIRouter(prefix="/api", tags=["auth"])
|
||||
|
||||
|
||||
@auth_router.post("/token")
|
||||
async def login(
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
conn: Annotated[connection, Depends(get_db_connection)]
|
||||
) -> Token:
|
||||
user = authenticate_user(conn, form_data.username, form_data.password) # change db
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
access_token_expire_time = timedelta(minutes=settings.access_token_expiration_time)
|
||||
access_token = create_access_token(
|
||||
data={"sub": user.username}, expires_delta=access_token_expire_time
|
||||
)
|
||||
return Token(access_token=access_token, token_type="bearer")
|
||||
28
src/api/models.py
Normal file
28
src/api/models.py
Normal file
@ -0,0 +1,28 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
username: str
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
def fill(self, params):
|
||||
self.username = params['username']
|
||||
self.password = params['password']
|
||||
self.disabled = params['disabled']
|
||||
self.groups_ids = params['groups_ids']
|
||||
self.last_seen_at = params['last_seen_at']
|
||||
self.created_at = params['created_at']
|
||||
username: str = ''
|
||||
password: str = ''
|
||||
disabled: bool = False
|
||||
groups_ids: list[str] | None = None
|
||||
last_seen_at: datetime | None = None
|
||||
created_at: datetime | None = None
|
||||
8
src/api/status.py
Normal file
8
src/api/status.py
Normal file
@ -0,0 +1,8 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
status_router = APIRouter(prefix="/api", tags=["status"])
|
||||
|
||||
|
||||
@status_router.get('/ping')
|
||||
async def ping():
|
||||
return {'ok'}
|
||||
21
src/api/tests.py
Normal file
21
src/api/tests.py
Normal file
@ -0,0 +1,21 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
|
||||
from api.models import User
|
||||
from api.utils import get_current_user
|
||||
|
||||
test_router = APIRouter(prefix="/api", tags=["test"])
|
||||
|
||||
|
||||
@test_router.get('/test-private')
|
||||
async def test_private_func(token: Annotated[User, Depends(get_current_user)]):
|
||||
return {'private nya'}
|
||||
|
||||
|
||||
@test_router.post('/test')
|
||||
async def test_func(text: str):
|
||||
print(text)
|
||||
if text == 'thighs':
|
||||
raise HTTPException(status_code=status.HTTP_402_PAYMENT_REQUIRED)
|
||||
return {'nya'}
|
||||
33
src/api/users.py
Normal file
33
src/api/users.py
Normal file
@ -0,0 +1,33 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from psycopg2._psycopg import connection
|
||||
|
||||
import db.users as db
|
||||
from api.models import User
|
||||
from api.utils import get_current_user
|
||||
from db.internal import get_db_connection
|
||||
|
||||
users_router = APIRouter(prefix="/api/users", tags=["users"])
|
||||
|
||||
|
||||
@users_router.get("/me")
|
||||
async def read_users_me(current_user: Annotated[User, Depends(get_current_user)]):
|
||||
return current_user
|
||||
|
||||
|
||||
@users_router.post("/user")
|
||||
async def read_users_any(username: str, conn: Annotated[connection, Depends(get_db_connection)]):
|
||||
user = User()
|
||||
user.fill(db.get_user(conn, username))
|
||||
return user
|
||||
|
||||
|
||||
@users_router.post("/add")
|
||||
async def add_user(username: str, password: str, conn: Annotated[connection, Depends(get_db_connection)]):
|
||||
return db.create_user(conn, username, password)
|
||||
|
||||
|
||||
@users_router.post("/delete")
|
||||
async def delete_user(username: str, conn: Annotated[connection, Depends(get_db_connection)]):
|
||||
return db.delete_user(conn, username)
|
||||
88
src/api/utils.py
Normal file
88
src/api/utils.py
Normal file
@ -0,0 +1,88 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Annotated
|
||||
|
||||
import jwt
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jwt.exceptions import InvalidTokenError
|
||||
from passlib.context import CryptContext
|
||||
from psycopg2._psycopg import connection
|
||||
|
||||
import db.users
|
||||
import settings.settings as settings
|
||||
from api.models import TokenData, User
|
||||
from db.internal import get_db_connection
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
|
||||
def decode_token(token):
|
||||
return jwt.decode(token, settings.secret_key, algorithms=[settings.algorithm])
|
||||
|
||||
def encode_token(payload):
|
||||
return jwt.encode(payload, settings.secret_key, algorithm=settings.algorithm)
|
||||
|
||||
def verify_password(plain_password, hashed_password):
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
|
||||
def authenticate_user(
|
||||
conn: connection,
|
||||
username: str,
|
||||
password: str
|
||||
):
|
||||
user = User()
|
||||
userdata = db.users.get_user(conn, username)
|
||||
if not userdata:
|
||||
return False
|
||||
if not verify_password(password, user.password):
|
||||
return False
|
||||
user.fill(userdata)
|
||||
return user
|
||||
|
||||
def create_access_token(
|
||||
data: dict,
|
||||
expires_delta: timedelta
|
||||
):
|
||||
encode_payload = data.copy()
|
||||
expire_moment = datetime.now(timezone.utc) + expires_delta
|
||||
encode_payload.update({"exp": expire_moment})
|
||||
encoded_jwt = encode_token(encode_payload)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: Annotated[str, Depends(oauth2_scheme)],
|
||||
conn: Annotated[connection, Depends(get_db_connection)]
|
||||
):
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
try:
|
||||
payload = decode_token(token)
|
||||
print(payload)
|
||||
username = payload.get("sub")
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
token_data = TokenData(username=username)
|
||||
except InvalidTokenError:
|
||||
raise credentials_exception
|
||||
|
||||
user = User()
|
||||
user.fill(db.users.get_user(conn, username=token_data.username))
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
|
||||
if user.disabled:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Inactive user"
|
||||
)
|
||||
|
||||
return user
|
||||
30
src/create_app.py
Normal file
30
src/create_app.py
Normal file
@ -0,0 +1,30 @@
|
||||
from fastapi import FastAPI
|
||||
|
||||
from api.auth import auth_router
|
||||
from api.status import status_router
|
||||
from api.tests import test_router
|
||||
from api.users import users_router
|
||||
from db.internal import connect_db, disconnect_db
|
||||
from settings import settings
|
||||
|
||||
docs_url = None
|
||||
if settings.swagger_enabled:
|
||||
docs_url = "/api/docs"
|
||||
|
||||
app = FastAPI(
|
||||
redoc_url=None,
|
||||
docs_url=docs_url,
|
||||
)
|
||||
|
||||
|
||||
def create_app():
|
||||
app.add_event_handler("startup", connect_db)
|
||||
|
||||
app.include_router(status_router)
|
||||
app.include_router(auth_router)
|
||||
app.include_router(users_router)
|
||||
app.include_router(test_router)
|
||||
|
||||
app.add_event_handler("shutdown", disconnect_db)
|
||||
|
||||
return app
|
||||
44
src/db/internal.py
Normal file
44
src/db/internal.py
Normal file
@ -0,0 +1,44 @@
|
||||
import sys
|
||||
|
||||
import psycopg2
|
||||
from loguru import logger
|
||||
|
||||
from db.models import database
|
||||
from settings import settings
|
||||
|
||||
|
||||
def connect_db():
|
||||
logger.info("Initializing DB connection")
|
||||
try:
|
||||
database.conn = psycopg2.connect(
|
||||
dbname=settings.db_name,
|
||||
user=settings.db_user,
|
||||
password=settings.db_password,
|
||||
host=settings.db_host,
|
||||
port=settings.db_port,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize DB connection: {e}")
|
||||
sys.exit(1)
|
||||
logger.success("Successfully initialized DB connection")
|
||||
|
||||
|
||||
def disconnect_db():
|
||||
logger.info("Closing DB connection")
|
||||
if database.conn:
|
||||
try:
|
||||
database.conn.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to disconnect from DB: {e}")
|
||||
return
|
||||
else:
|
||||
logger.error("Failed to disconnect from DB: no connection")
|
||||
logger.success("Successfully closed DB connection")
|
||||
|
||||
|
||||
def get_db_connection():
|
||||
if database.conn is not None:
|
||||
yield database.conn
|
||||
else:
|
||||
logger.error("No connection pool")
|
||||
sys.exit(1)
|
||||
8
src/db/models.py
Normal file
8
src/db/models.py
Normal file
@ -0,0 +1,8 @@
|
||||
from psycopg2._psycopg import connection
|
||||
|
||||
|
||||
class DataBase:
|
||||
conn: connection | None = None
|
||||
|
||||
|
||||
database = DataBase()
|
||||
157
src/db/users.py
Normal file
157
src/db/users.py
Normal file
@ -0,0 +1,157 @@
|
||||
import psycopg2.extras
|
||||
from psycopg2._psycopg import connection
|
||||
|
||||
# user create and delete
|
||||
|
||||
def create_user(
|
||||
conn: connection,
|
||||
username: str,
|
||||
password: str
|
||||
):
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
insert into picrinth.users
|
||||
(username, password, disabled, created_at)
|
||||
values (%s, %s, false, now())
|
||||
""",
|
||||
(username, password),
|
||||
)
|
||||
conn.commit()
|
||||
return cur.rowcount > 0
|
||||
|
||||
|
||||
def delete_user(
|
||||
conn: connection,
|
||||
username: str
|
||||
):
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
delete from picrinth.users
|
||||
where username = %s
|
||||
""",
|
||||
(username,),
|
||||
)
|
||||
conn.commit()
|
||||
return cur.rowcount > 0
|
||||
|
||||
|
||||
# user checks
|
||||
|
||||
def check_user_existence(
|
||||
conn: connection,
|
||||
username: str
|
||||
):
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
select exists(
|
||||
select 1
|
||||
from picrinth.users
|
||||
where username = %s
|
||||
);
|
||||
""",
|
||||
(username,),
|
||||
)
|
||||
return cur.fetchone()
|
||||
|
||||
def check_user_disabled(
|
||||
conn: connection,
|
||||
username: str
|
||||
):
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
select disabled
|
||||
from picrinth.users
|
||||
where username = %s;
|
||||
""",
|
||||
(username,),
|
||||
)
|
||||
return cur.fetchone()
|
||||
|
||||
|
||||
# user updates
|
||||
|
||||
def update_user_password(
|
||||
conn: connection,
|
||||
username: str,
|
||||
password: str
|
||||
):
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
update picrinth.users
|
||||
set password = %s
|
||||
where username = %s
|
||||
""",
|
||||
(password, username),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def update_user_username(
|
||||
conn: connection,
|
||||
username: str,
|
||||
newUsername: str
|
||||
):
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
update picrinth.users
|
||||
set username = %s
|
||||
where username = %s;
|
||||
""",
|
||||
(newUsername, username),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def update_user_last_seen(
|
||||
conn: connection,
|
||||
username: str
|
||||
):
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
update picrinth.users
|
||||
set last_seen_at = now()
|
||||
where username = %s
|
||||
""",
|
||||
(username,),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
# user receiving
|
||||
|
||||
def get_user(
|
||||
conn: connection,
|
||||
username: str
|
||||
):
|
||||
with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
select username, password, disabled,
|
||||
groups_ids, last_seen_at, created_at
|
||||
from picrinth.users
|
||||
where username = %s
|
||||
""",
|
||||
(username,),
|
||||
)
|
||||
return cur.fetchone()
|
||||
|
||||
def get_user_password(
|
||||
conn: connection,
|
||||
username: str
|
||||
):
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
select password
|
||||
from picrinth.users
|
||||
where username = %s
|
||||
""",
|
||||
(username,),
|
||||
)
|
||||
return cur.fetchone()
|
||||
0
src/settings/consts.py
Normal file
0
src/settings/consts.py
Normal file
22
src/settings/settings.py
Normal file
22
src/settings/settings.py
Normal file
@ -0,0 +1,22 @@
|
||||
from decouple import config
|
||||
|
||||
|
||||
def str_to_bool(string: str) -> bool:
|
||||
if string.lower() == 'true':
|
||||
return True
|
||||
return False
|
||||
|
||||
# database
|
||||
db_host = str(config('db_host', default='127.0.0.1'))
|
||||
db_port = int(config('db_port', default=5432))
|
||||
db_name = str(config('db_name', default='postgres'))
|
||||
db_user = str(config('db_user', default='postgres'))
|
||||
db_password = str(config('db_password', default='postgres'))
|
||||
|
||||
# auth
|
||||
secret_key = str(config('secret_key'))
|
||||
algorithm = str(config('algorithm', 'HS256'))
|
||||
access_token_expiration_time = int(config('access_token_expiration_time', default=10080))
|
||||
|
||||
# other settings
|
||||
swagger_enabled = str_to_bool(str(config('swagger_enabled', 'false')))
|
||||
Reference in New Issue
Block a user