90 lines
3.2 KiB
Python
90 lines
3.2 KiB
Python
from fastapi import HTTPException
|
|
from passlib.context import CryptContext
|
|
from sqlalchemy.orm import Session
|
|
|
|
from src.crud.highscores import (get_highest_high_scores,
|
|
get_most_recent_high_scores)
|
|
|
|
from src.crud.courseprogress import get_course_progress
|
|
from src.enums import CourseEnum, MinigameEnum
|
|
from src.models import CourseProgress, LearnableProgress, User
|
|
from src.schemas.courseprogress import SavedCourseProgress
|
|
from src.schemas.highscores import SavedMinigameProgress
|
|
from src.schemas.learnableprogress import SavedLearnableProgress
|
|
from src.schemas.users import SavedUser, UserCreate
|
|
|
|
# Module-wide password hashing context; bcrypt is the only configured scheme.
# patch_user uses it to hash a newly supplied password before storing it.
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])
|
|
|
|
|
|
def check_empty_fields(username: str, password: str, avatar_index: int) -> None:
    """Validate that the required user fields were provided.

    Checks run in a fixed order (avatar index, then username, then password)
    and the first failing check raises.

    Args:
        username: Requested username; must be non-empty.
        password: Plain-text password; must be non-empty.
        avatar_index: Selected avatar; a negative value means none was given.

    Raises:
        HTTPException: 400 if ``avatar_index`` is negative, or if
            ``username`` / ``password`` is empty (``None`` counts as empty).
    """
    if avatar_index < 0:
        raise HTTPException(status_code=400, detail="No avatar was provided")
    # Truthiness rather than ``len(...) == 0`` so a missing (None) value is
    # reported as a 400 instead of crashing with a TypeError.
    if not username:
        raise HTTPException(status_code=400, detail="No username was provided")
    if not password:
        raise HTTPException(status_code=400, detail="No password was provided")
|
|
|
|
|
|
def patch_user(db: Session, username: str, user: UserCreate):
    """Apply a partial update to an existing User and commit it.

    Fields of ``user`` holding their "unset" value are left untouched:
    - an empty ``username`` or ``password`` keeps the current value;
    - an ``avatar_index`` of -1 keeps the current avatar.

    ``playtime`` is added to the stored total rather than replacing it.

    Every validation runs before any attribute of the stored User is
    modified, so a rejected request never leaves half-applied changes on the
    object loaded from ``db``.

    Args:
        db: Active database session; committed on success.
        username: Current username of the User to update.
        user: Requested changes.

    Raises:
        HTTPException: 404 if no User named ``username`` exists.
        HTTPException: 400 if the new username belongs to a different User,
            if ``playtime`` is negative, or if ``avatar_index`` is below -1.
    """
    db_user = get_user_by_username(db, username)
    if db_user is None:
        raise HTTPException(status_code=404, detail="User not found")

    # Keeping one's own current name is allowed; only another user holding
    # the requested name is a conflict. Skip the lookup when no new name.
    if user.username:
        potential_duplicate = get_user_by_username(db, user.username)
        if potential_duplicate and potential_duplicate.user_id != db_user.user_id:
            raise HTTPException(status_code=400, detail="Username already registered")

    if user.playtime < 0:
        raise HTTPException(status_code=400, detail="Negative playtime is invalid")

    # -1 means "keep the current avatar"; anything lower is invalid.
    if user.avatar_index < -1:
        raise HTTPException(status_code=400, detail="Invalid avatar index")

    # All checks passed: only now mutate the loaded User.
    if user.username:
        db_user.username = user.username
    if user.password:
        db_user.hashed_password = pwd_context.hash(user.password)
    if user.avatar_index > -1:
        db_user.avatar_index = user.avatar_index
    db_user.playtime += user.playtime

    db.commit()
|
|
|
|
|
|
def get_user_by_username(db: Session, username: str):
    """Look up a User by exact username.

    Returns the first matching User, or None when no row matches.
    """
    matching_users = db.query(User).filter_by(username=username)
    return matching_users.first()
|
|
|
|
|
|
def get_users(db: Session):
    """Return a list containing every User stored in the database."""
    all_users_query = db.query(User)
    return all_users_query.all()
|
|
|
|
|
|
def get_saved_data(db: Session, username: str, *, score_limit: int = 10):
    """Collect all saved progress for a user into a ``SavedUser``.

    Args:
        db: Active database session.
        username: Username of the User whose data is collected.
        score_limit: Count passed to both high-score helpers for every
            minigame (presumably the number of scores returned — confirm
            against ``src.crud.highscores``). Keyword-only; defaults to the
            previously hard-coded 10.

    Returns:
        A ``SavedUser`` holding the user's username, avatar index, playtime,
        course progress for ``CourseEnum.All``, and one
        ``SavedMinigameProgress`` per ``MinigameEnum`` member.

    Raises:
        HTTPException: 404 if no User named ``username`` exists.
    """
    user = get_user_by_username(db, username)
    if user is None:
        raise HTTPException(status_code=404, detail="User not found")

    courses = get_course_progress(db, user, CourseEnum.All)

    # NOTE(review): latest_scores is fetched without the user argument while
    # highest_scores receives it — confirm whether recent scores are meant to
    # be global or per-user.
    minigames = [
        SavedMinigameProgress(
            minigame_index=minigame,
            latest_scores=get_most_recent_high_scores(db, minigame, score_limit),
            highest_scores=get_highest_high_scores(
                db, minigame, user, score_limit, False
            ),
        )
        for minigame in MinigameEnum
    ]

    return SavedUser(
        username=user.username,
        avatar_index=user.avatar_index,
        playtime=user.playtime,
        minigames=minigames,
        courses=courses,
    )
|