diff --git a/src/crud/courseprogress.py b/src/crud/courseprogress.py index c6d0a30..e3ba2b1 100644 --- a/src/crud/courseprogress.py +++ b/src/crud/courseprogress.py @@ -8,7 +8,14 @@ from schemas.courseprogress import CourseProgressBase def get_course_progress(db: Session, user: User, course: CourseEnum): """Get the progress a user has for a certain course""" - if course != CourseEnum.All: + result = [] + courses_to_fetch = [course] + if course == CourseEnum.All: + all_courses_list = [course for course in CourseEnum] + courses_to_fetch = filter( + lambda course: course != CourseEnum.All, all_courses_list + ) + for course in courses_to_fetch: course_progress = ( db.query(CourseProgress) .filter( @@ -18,19 +25,19 @@ def get_course_progress(db: Session, user: User, course: CourseEnum): ) if course_progress: - return [ + result.append( CourseProgressBase( - progress_value=course_progress.progress_value, - course=course_progress.course, + progress_value=course_progress.progress_value, course=course ) - ] + ) else: db.add( CourseProgress(progress_value=0.0, course=course, owner_id=user.user_id) ) db.commit() - return [CourseProgressBase(progress_value=0.0, course=course)] - return [] + result.append(CourseProgressBase(progress_value=0.0, course=course)) + + return result def initialize_user(db: Session, user: User): @@ -38,3 +45,27 @@ def initialize_user(db: Session, user: User): for course in course_enum_list: db.add(CourseProgress(progress_value=0.0, course=course, owner_id=user.user_id)) db.commit() + + +def patch_course_progress(db: Session, user: User, course_progress: CourseProgressBase): + """Change the progress value for a given course""" + db_course_progress_list = [] + if course_progress.course != CourseEnum.All: + db_course_progress_list = ( + db.query(CourseProgress) + .filter( + CourseProgress.owner_id == user.user_id, + CourseProgress.course == course_progress.course, + ) + .all() + ) + else: + db_course_progress_list = ( + db.query(CourseProgress) + .filter(CourseProgress.owner_id == user.user_id) + .all() + ) + + for db_course_progress in db_course_progress_list: + db_course_progress.progress_value = course_progress.progress_value + db.commit() diff --git a/src/main.py b/src/main.py index c7dea9c..63eae43 100644 --- a/src/main.py +++ b/src/main.py @@ -1,6 +1,5 @@ from typing import List, Optional -import jwt from fastapi import Depends, FastAPI, HTTPException from sqlalchemy.orm import Session @@ -70,16 +69,6 @@ async def create_high_score( return crud_highscores.create_high_score(db, current_user, high_score) -#### TESTING!! DELETE LATER - - -@app.get("/protected") -async def protected_route( - current_user_name: str = Depends(crud_authentication.get_current_user_name), -): - return {"message": f"Hello, {current_user_name}!"} - - @app.get("/courseprogress", response_model=List[courseprogress.CourseProgressBase]) async def get_course_progress( course: Optional[CourseEnum] = CourseEnum.All, @@ -92,9 +81,19 @@ async def get_course_progress( @app.patch("/courseprogress") async def get_course_progress( + course_progress: courseprogress.CourseProgressBase, current_user_name: str = Depends(crud_authentication.get_current_user_name), - course: Optional[CourseEnum] = CourseEnum.All, db: Session = Depends(get_db), ): current_user = crud_users.get_user_by_username(db, current_user_name) - return crud_courseprogress.get_course_progress(db, current_user, course) + return crud_courseprogress.patch_course_progress(db, current_user, course_progress) + + +#### TESTING!! DELETE LATER + + +@app.get("/protected") +async def protected_route( + current_user_name: str = Depends(crud_authentication.get_current_user_name), +): + return {"message": f"Hello, {current_user_name}!"}