from fastapi import HTTPException
from sqlalchemy.orm import Session

from src.enums import CourseEnum
from src.models import CourseProgress, User
from src.schemas.courseprogress import CourseProgressBase, CourseProgressParent


def get_course_progress(db: Session, user: User, course: CourseEnum):
    """Get the progress a user has for a certain course.

    If ``course`` is ``CourseEnum.All``, progress is returned for every
    ``CourseEnum`` member except ``All`` itself. Any requested course without a
    stored ``CourseProgress`` row for this user gets a new row with progress
    0.0, and 0.0 is reported for it.

    Args:
        db: Database session used for querying and persisting rows.
        user: User whose progress is read; rows are matched on ``user.user_id``.
        course: Course to read, or ``CourseEnum.All`` for all courses.

    Returns:
        A list of ``CourseProgressParent``, one per requested course, in
        ``CourseEnum`` iteration order.
    """
    if course == CourseEnum.All:
        courses_to_fetch = [c for c in CourseEnum if c != CourseEnum.All]
    else:
        courses_to_fetch = [course]

    result = []
    created_any = False
    for current_course in courses_to_fetch:
        course_progress = (
            db.query(CourseProgress)
            .filter(
                CourseProgress.owner_id == user.user_id,
                CourseProgress.course == current_course,
            )
            .first()
        )
        if course_progress is None:
            # Lazily create the missing record so later reads/patches find it.
            db.add(
                CourseProgress(
                    progress=0.0, course=current_course, owner_id=user.user_id
                )
            )
            created_any = True
            progress = 0.0
        else:
            progress = course_progress.progress
        result.append(CourseProgressParent(progress=progress, course=current_course))

    if created_any:
        # One commit for all newly created rows instead of one per course.
        db.commit()
    return result


def initialize_user(db: Session, user: User):
    """Create CourseProgress records with a value of 0 for a new user.

    One row is added for every ``CourseEnum`` member except ``All``, and all
    rows are committed together.
    """
    for course in CourseEnum:
        if course != CourseEnum.All:
            db.add(CourseProgress(progress=0.0, course=course, owner_id=user.user_id))
    db.commit()


def patch_course_progress(
    db: Session, user: User, course: CourseEnum, course_progress: CourseProgressBase
):
    """Change the progress value for a given course.

    Only existing rows are updated; no rows are created here. With
    ``CourseEnum.All``, every ``CourseProgress`` row owned by the user is set
    to the new value.

    Args:
        db: Database session used for querying and committing.
        user: User whose rows are updated; matched on ``user.user_id``.
        course: Course to update, or ``CourseEnum.All`` for all of the user's rows.
        course_progress: Payload carrying the new ``progress`` value.

    Returns:
        A list of ``CourseProgressParent`` for every row that was updated
        (empty if the user has no matching rows).

    Raises:
        HTTPException: 400 if the progress value is outside [0, 1] or NaN.
    """
    progress = course_progress.progress
    # The chained comparison is False for NaN, so NaN is rejected too
    # (the former `> 1 or < 0` check let NaN through).
    if not 0 <= progress <= 1:
        raise HTTPException(status_code=400, detail="Invalid progress value")

    query = db.query(CourseProgress).filter(CourseProgress.owner_id == user.user_id)
    if course != CourseEnum.All:
        # Successive .filter() calls are combined with AND.
        query = query.filter(CourseProgress.course == course)
    db_course_progress_list = query.all()

    for db_course_progress in db_course_progress_list:
        db_course_progress.progress = progress
    db.commit()

    return [
        CourseProgressParent(course=db_cp.course, progress=db_cp.progress)
        for db_cp in db_course_progress_list
    ]