diff --git a/src/crud/authentication.py b/src/crud/authentication.py index 5986a57..e70fb85 100644 --- a/src/crud/authentication.py +++ b/src/crud/authentication.py @@ -1,7 +1,8 @@ from datetime import datetime, timedelta import jwt -from fastapi import HTTPException +from fastapi import Depends, HTTPException +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from sqlalchemy.orm import Session from crud.users import get_user_by_username, pwd_context @@ -15,6 +16,25 @@ jwt_secret = "secret_key" ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES = 44640 # One month +bearer_scheme = HTTPBearer() + + +def get_current_user_name( + token: HTTPAuthorizationCredentials = Depends(bearer_scheme), +): + try: + payload = jwt.decode( + token.credentials, + jwt_secret, + algorithms=[ALGORITHM], + ) + username = payload.get("sub") + if username is None: + raise HTTPException(status_code=401, detail="Invalid JWT token") + return username + except jwt.exceptions.DecodeError: + raise HTTPException(status_code=401, detail="Invalid JWT token") + def authenticate_user(db: Session, username: str, password: str): """Checks whether the provided credentials match with an existing User""" @@ -29,10 +49,14 @@ def authenticate_user(db: Session, username: str, password: str): def register(db: Session, username: str, password: str, avatar: str): """Register a new user""" + if avatar == "": + raise HTTPException(status_code=400, detail="No avatar was provided") db_user = get_user_by_username(db, username) if db_user: raise HTTPException(status_code=400, detail="Username already registered") - db_user = User(username = username, hashed_password = pwd_context.hash(password), avatar = avatar) + db_user = User( + username=username, hashed_password=pwd_context.hash(password), avatar=avatar + ) db.add(db_user) db.commit() db.refresh(db_user) diff --git a/src/crud/courseprogress.py b/src/crud/courseprogress.py index 6a2f17e..c6d0a30 100644 --- a/src/crud/courseprogress.py +++ b/src/crud/courseprogress.py @@ -2,7 +2,8 @@ from fastapi import HTTPException from sqlalchemy.orm import Session from enums import CourseEnum, course_enum_list -from models import User, CourseProgress +from models import CourseProgress, User +from schemas.courseprogress import CourseProgressBase def get_course_progress(db: Session, user: User, course: CourseEnum): @@ -15,21 +16,25 @@ def get_course_progress(db: Session, user: User, course: CourseEnum): ) .first() ) + if course_progress: return [ CourseProgressBase( - progress_value = course_progress.progress_value, - course = course_progress.course, + progress_value=course_progress.progress_value, + course=course_progress.course, ) ] else: - return [CourseProgressBase(progress_value = 0, course=course)] + db.add( + CourseProgress(progress_value=0.0, course=course, owner_id=user.user_id) + ) + db.commit() + return [CourseProgressBase(progress_value=0.0, course=course)] return [] def initialize_user(db: Session, user: User): + """Create CourseProgress records with a value of 0 for a new user""" for course in course_enum_list: - db.add(CourseProgress(progress_value = 0.0, course = course, owner_id = user.user_id)) + db.add(CourseProgress(progress_value=0.0, course=course, owner_id=user.user_id)) db.commit() - - diff --git a/src/crud/highscores.py b/src/crud/highscores.py index 6471c34..17b2cce 100644 --- a/src/crud/highscores.py +++ b/src/crud/highscores.py @@ -31,19 +31,24 @@ def get_high_scores(db: Session, minigame: MinigameEnum, nr_highest: int): for high_score in high_scores: owner = db.query(User).filter(User.user_id == high_score.owner_id).first() user_high_scores.append( - UserHighScore(username = owner.username, score_value = high_score.score_value, avatar = owner.avatar) + UserHighScore( + username=owner.username, + score_value=high_score.score_value, + avatar=owner.avatar, + ) ) return user_high_scores def create_high_score(db: Session, user: User, high_score: HighScoreBase): """Create a new high score for a given minigame""" + def add_to_db(): """Helper function that adds new score to database; prevents code duplication""" db_high_score = HighScore( - score_value = high_score.score_value, - minigame = high_score.minigame, - owner_id = user.user_id, + score_value=high_score.score_value, + minigame=high_score.minigame, + owner_id=user.user_id, ) db.add(db_high_score) db.commit() diff --git a/src/crud/users.py b/src/crud/users.py index 4817d27..50cadd7 100644 --- a/src/crud/users.py +++ b/src/crud/users.py @@ -18,6 +18,7 @@ def patch_user(db: Session, username: str, user: UserCreate): db_user.username = user.username db_user.hashed_password = pwd_context.hash(user.password) + db_user.avatar = user.avatar db.commit() diff --git a/src/database.py b/src/database.py index 533adf5..aca9992 100644 --- a/src/database.py +++ b/src/database.py @@ -9,3 +9,11 @@ engine = create_engine(SQLALCHEMY_DATABASE_URL) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) Base = declarative_base() + + +def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.close() diff --git a/src/enums.py b/src/enums.py index 088d795..3b6ecd4 100644 --- a/src/enums.py +++ b/src/enums.py @@ -28,20 +28,21 @@ class MinigameEnum(StrEnum): class CourseEnum(StrEnum): Fingerspelling = "Fingerspelling" - #Basics = "Basics" + Basics = "Basics" Hobbies = "Hobbies" Animals = "Animals" Colors = "Colors" FruitsVegetables = "FruitsVegetables" All = "All" + # This is needed because for some reason iterating over an enum doesn't work properly... course_enum_list = [ CourseEnum.Fingerspelling, - #CourseEnum.Basics, + # CourseEnum.Basics, CourseEnum.Hobbies, CourseEnum.Animals, CourseEnum.Colors, CourseEnum.FruitsVegetables, - CourseEnum.All + CourseEnum.All, ] diff --git a/src/main.py b/src/main.py index 5a9f139..c7dea9c 100644 --- a/src/main.py +++ b/src/main.py @@ -2,46 +2,19 @@ from typing import List, Optional import jwt from fastapi import Depends, FastAPI, HTTPException -from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from sqlalchemy.orm import Session from crud import authentication as crud_authentication from crud import courseprogress as crud_courseprogress from crud import highscores as crud_highscores from crud import users as crud_users -from database import SessionLocal, engine +from database import SessionLocal, engine, get_db from enums import CourseEnum, MinigameEnum from models import Base from schemas import courseprogress, highscores, users app = FastAPI() Base.metadata.create_all(bind=engine) -bearer_scheme = HTTPBearer() - - -def get_db(): - db = SessionLocal() - try: - yield db - finally: - db.close() - - -def get_current_user_name( - token: HTTPAuthorizationCredentials = Depends(bearer_scheme), -): - try: - payload = jwt.decode( - token.credentials, - crud_authentication.jwt_secret, - algorithms=[crud_authentication.ALGORITHM], - ) - username = payload.get("sub") - if username is None: - raise HTTPException(status_code=401, detail="Invalid JWT token") - return username - except jwt.exceptions.DecodeError: - raise HTTPException(status_code=401, detail="Invalid JWT token") @app.get("/") @@ -57,15 +30,17 @@ async def read_users(db: Session = Depends(get_db)): @app.patch("/users") async def patch_current_user( user: users.UserCreate, - current_user_name=Depends(get_current_user_name), - db: Session = Depends(get_db) + current_user_name=Depends(crud_authentication.get_current_user_name), + db: Session = Depends(get_db), ): crud_users.patch_user(db, current_user_name, user) @app.post("/register") async def register(user: users.UserCreate, db: Session = Depends(get_db)): - access_token = crud_authentication.register(db, user.username, user.password, user.avatar) + access_token = crud_authentication.register( + db, user.username, user.password, user.avatar + ) user = crud_users.get_user_by_username(db, user.username) crud_courseprogress.initialize_user(db, user) return access_token @@ -80,7 +55,7 @@ async def login(user: users.UserCreate, db: Session = Depends(get_db)): async def get_high_scores( minigame: Optional[MinigameEnum] = None, nr_highest: Optional[int] = None, - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): return crud_highscores.get_high_scores(db, minigame, nr_highest) @@ -88,8 +63,8 @@ async def get_high_scores( @app.post("/highscores", response_model=highscores.HighScore) async def create_high_score( high_score: highscores.HighScoreBase, - current_user_name = Depends(get_current_user_name), - db: Session = Depends(get_db) + current_user_name=Depends(crud_authentication.get_current_user_name), + db: Session = Depends(get_db), ): current_user = crud_users.get_user_by_username(db, current_user_name) return crud_highscores.create_high_score(db, current_user, high_score) @@ -99,15 +74,27 @@ async def create_high_score( @app.get("/protected") -async def protected_route(current_user_name=Depends(get_current_user_name)): +async def protected_route( + current_user_name: str = Depends(crud_authentication.get_current_user_name), +): return {"message": f"Hello, {current_user_name}!"} @app.get("/courseprogress", response_model=List[courseprogress.CourseProgressBase]) async def get_course_progress( course: Optional[CourseEnum] = CourseEnum.All, - current_user_name=Depends(get_current_user_name), - db: Session = Depends(get_db) + current_user_name: str = Depends(crud_authentication.get_current_user_name), + db: Session = Depends(get_db), +): + current_user = crud_users.get_user_by_username(db, current_user_name) + return crud_courseprogress.get_course_progress(db, current_user, course) + + +@app.patch("/courseprogress") +async def get_course_progress( + current_user_name: str = Depends(crud_authentication.get_current_user_name), + course: Optional[CourseEnum] = CourseEnum.All, + db: Session = Depends(get_db), ): current_user = crud_users.get_user_by_username(db, current_user_name) return crud_courseprogress.get_course_progress(db, current_user, course) diff --git a/src/models.py b/src/models.py index 948f6b8..d2f1f8d 100644 --- a/src/models.py +++ b/src/models.py @@ -7,6 +7,8 @@ from enums import CourseEnum, MinigameEnum, StrEnumType class User(Base): + """The database model for users""" + __tablename__ = "users" user_id = Column(Integer, primary_key=True, index=True) @@ -23,6 +25,8 @@ class User(Base): class HighScore(Base): + """The database model for high scores""" + __tablename__ = "high_scores" high_score_id = Column(Integer, primary_key=True, index=True) @@ -33,6 +37,8 @@ class HighScore(Base): class CourseProgress(Base): + """The database model for course progress""" + __tablename__ = "course_progress" course_progress_id = Column(Integer, primary_key=True, index=True) diff --git a/src/schemas/users.py b/src/schemas/users.py index a14c4a3..3077af7 100644 --- a/src/schemas/users.py +++ b/src/schemas/users.py @@ -5,7 +5,7 @@ from pydantic import BaseModel class UserBase(BaseModel): username: str - avatar: str + avatar: str = "" class User(UserBase):