diff --git a/src/crud/courseprogress.py b/src/crud/courseprogress.py index 66eda01..3000804 100644 --- a/src/crud/courseprogress.py +++ b/src/crud/courseprogress.py @@ -3,7 +3,7 @@ from fastapi import HTTPException from src.enums import CourseEnum from src.models import CourseProgress, User -from src.schemas.courseprogress import CourseProgressBase +from src.schemas.courseprogress import CourseProgressBase, CourseProgressParent def get_course_progress(db: Session, user: User, course: CourseEnum): @@ -26,7 +26,7 @@ def get_course_progress(db: Session, user: User, course: CourseEnum): if course_progress: result.append( - CourseProgressBase( + CourseProgressParent( progress_value=course_progress.progress_value, course=course ) ) @@ -35,7 +35,7 @@ def get_course_progress(db: Session, user: User, course: CourseEnum): CourseProgress(progress_value=0.0, course=course, owner_id=user.user_id) ) db.commit() - result.append(CourseProgressBase(progress_value=0.0, course=course)) + result.append(CourseProgressParent(progress_value=0.0, course=course)) return result @@ -50,18 +50,18 @@ def initialize_user(db: Session, user: User): db.commit() -def patch_course_progress(db: Session, user: User, course_progress: CourseProgressBase): +def patch_course_progress(db: Session, user: User, course: CourseEnum, course_progress: CourseProgressBase): """Change the progress value for a given course""" if course_progress.progress_value > 1 or course_progress.progress_value < 0: raise HTTPException(status_code=400, detail="Invalid progress value") db_course_progress_list = [] - if course_progress.course != CourseEnum.All: + if course != CourseEnum.All: db_course_progress_list = ( db.query(CourseProgress) .filter( CourseProgress.owner_id == user.user_id, - CourseProgress.course == course_progress.course, + CourseProgress.course == course, ) .all() ) @@ -71,12 +71,12 @@ def patch_course_progress(db: Session, user: User, course_progress: CourseProgre .filter(CourseProgress.owner_id == user.user_id) .all() ) - print(f"LENGTH OF LIST OF {course_progress.course}: {len(db_course_progress_list)}") + for db_course_progress in db_course_progress_list: db_course_progress.progress_value = course_progress.progress_value db.commit() return [ - CourseProgressBase(course=db_cp.course, progress_value=db_cp.progress_value) + CourseProgressParent(course=db_cp.course, progress_value=db_cp.progress_value) for db_cp in db_course_progress_list ] diff --git a/src/crud/highscores.py b/src/crud/highscores.py index 5809027..ffb2961 100644 --- a/src/crud/highscores.py +++ b/src/crud/highscores.py @@ -42,14 +42,14 @@ def get_high_scores(db: Session, minigame: MinigameEnum, nr_highest: int): return user_high_scores -def create_high_score(db: Session, user: User, high_score: HighScoreBase): +def create_high_score(db: Session, user: User, minigame: MinigameEnum, high_score: HighScoreBase): """Create a new high score for a given minigame""" def add_to_db(): """Helper function that adds new score to database; prevents code duplication""" db_high_score = HighScore( score_value=high_score.score_value, - minigame=high_score.minigame, + minigame=minigame, owner_id=user.user_id, ) db.add(db_high_score) @@ -61,7 +61,7 @@ def create_high_score(db: Session, user: User, high_score: HighScoreBase): db.query(HighScore) .filter( HighScore.owner_id == user.user_id, - HighScore.minigame == high_score.minigame, + HighScore.minigame == minigame, ) .first() ) diff --git a/src/main.py b/src/main.py index 71cf170..93cbd01 100644 --- a/src/main.py +++ b/src/main.py @@ -63,25 +63,26 @@ async def login(user: users.UserCreate, db: Session = Depends(get_db)): @app.get("/highscores/{minigame}", response_model=List[users.UserHighScore]) async def get_high_scores( - minigame: Optional[MinigameEnum] = None, + minigame: MinigameEnum, nr_highest: Optional[int] = None, db: Session = Depends(get_db), ): return crud_highscores.get_high_scores(db, minigame, nr_highest) -@app.post("/highscores", response_model=highscores.HighScore) +@app.post("/highscores/{minigame}", response_model=highscores.HighScore) async def create_high_score( + minigame: MinigameEnum, high_score: highscores.HighScoreBase, current_user_name=Depends(crud_authentication.get_current_user_name), db: Session = Depends(get_db), ): current_user = crud_users.get_user_by_username(db, current_user_name) - return crud_highscores.create_high_score(db, current_user, high_score) + return crud_highscores.create_high_score(db, current_user, minigame, high_score) @app.get( - "/courseprogress/{course}", response_model=List[courseprogress.CourseProgressBase] + "/courseprogress/{course}", response_model=List[courseprogress.CourseProgressParent] ) async def get_course_progress( course: Optional[CourseEnum] = CourseEnum.All, @@ -92,14 +93,15 @@ async def get_course_progress( return crud_courseprogress.get_course_progress(db, current_user, course) -@app.patch("/courseprogress", response_model=List[courseprogress.CourseProgressBase]) +@app.patch("/courseprogress/{course}", response_model=List[courseprogress.CourseProgressParent]) async def patch_course_progress( + course: CourseEnum, course_progress: courseprogress.CourseProgressBase, current_user_name: str = Depends(crud_authentication.get_current_user_name), db: Session = Depends(get_db), ): current_user = crud_users.get_user_by_username(db, current_user_name) - return crud_courseprogress.patch_course_progress(db, current_user, course_progress) + return crud_courseprogress.patch_course_progress(db, current_user, course, course_progress) #### TESTING!! DELETE LATER diff --git a/src/schemas/courseprogress.py b/src/schemas/courseprogress.py index 768890d..2d8327c 100644 --- a/src/schemas/courseprogress.py +++ b/src/schemas/courseprogress.py @@ -5,10 +5,13 @@ from src.enums import CourseEnum class CourseProgressBase(BaseModel): progress_value: float + + +class CourseProgressParent(CourseProgressBase): course: CourseEnum -class CourseProgress(CourseProgressBase): +class CourseProgress(CourseProgressParent): course_progress_id: int owner_id: int diff --git a/src/schemas/highscores.py b/src/schemas/highscores.py index 39632ba..22849b1 100644 --- a/src/schemas/highscores.py +++ b/src/schemas/highscores.py @@ -5,12 +5,12 @@ from src.enums import MinigameEnum class HighScoreBase(BaseModel): score_value: float - minigame: MinigameEnum class HighScore(HighScoreBase): high_score_id: int owner_id: int + minigame: MinigameEnum class Config: orm_mode = True diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 825d73e..d13cf05 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -12,22 +12,32 @@ app.dependency_overrides[get_db] = override_get_db client = TestClient(app) -username1 = "user1" -username2 = "user2" +username = "user1" password = "password" avatar = "lion" +async def register_user(): + response = client.post( + "/register", + headers={"Content-Type": "application/json"}, + json={"username": username, "password": password, "avatar": avatar}, + ) + + assert response.status_code == 200 + + return response.json()["access_token"] + + @pytest.mark.asyncio async def test_register(): - """LEAVE THIS TEST AT THE TOP OF THE FILE!""" """Test the register endpoint""" clear_db() response = client.post( "/register", headers={"Content-Type": "application/json"}, - json={"username": username1, "password": password, "avatar": avatar}, + json={"username": username, "password": password, "avatar": avatar}, ) assert response.status_code == 200 @@ -37,10 +47,13 @@ async def test_register(): @pytest.mark.asyncio async def test_register_duplicate_name_should_fail(): """Test whether registering a user with an existing username fails""" + clear_db() + await register_user() + response = client.post( "/register", headers={"Content-Type": "application/json"}, - json={"username": username1, "password": password, "avatar": avatar}, + json={"username": username, "password": password, "avatar": avatar}, ) assert response.status_code == 400 @@ -50,6 +63,8 @@ async def test_register_duplicate_name_should_fail(): @pytest.mark.asyncio async def test_register_without_username_should_fail(): """Test whether registering a user without passing a username fails""" + clear_db() + response = client.post( "/register", headers={"Content-Type": "application/json"}, @@ -63,10 +78,12 @@ async def test_register_without_username_should_fail(): @pytest.mark.asyncio async def test_register_without_password_should_fail(): """Test whether registering a user without passing a password fails""" + clear_db() + response = client.post( "/register", headers={"Content-Type": "application/json"}, - json={"username": username2, "avatar": avatar}, + json={"username": username, "avatar": avatar}, ) assert response.status_code == 422 @@ -76,10 +93,12 @@ async def test_register_without_password_should_fail(): @pytest.mark.asyncio async def test_register_without_avatar_should_fail(): """Test whether registering a user without passing an avatar fails""" + clear_db() + response = client.post( "/register", headers={"Content-Type": "application/json"}, - json={"username": username2, "password": password}, + json={"username": username, "password": password}, ) # Not ideal that this is 400 instead of 422, but had no other choice than to give this field a default value @@ -90,10 +109,13 @@ async def test_register_without_avatar_should_fail(): @pytest.mark.asyncio async def test_login(): """Test the login endpoint""" + clear_db() + await register_user() + response = client.post( "/login", headers={"Content-Type": "application/json"}, - json={"username": username1, "password": password}, + json={"username": username, "password": password}, ) assert response.status_code == 200 @@ -102,11 +124,14 @@ async def test_login(): @pytest.mark.asyncio async def test_login_wrong_password_should_fail(): + clear_db() + await register_user() + wrong_password = password + "extra characters" response = client.post( "/login", headers={"Content-Type": "application/json"}, - json={"username": username1, "password": wrong_password}, + json={"username": username, "password": wrong_password}, ) assert response.status_code == 401 @@ -116,10 +141,13 @@ async def test_login_wrong_password_should_fail(): @pytest.mark.asyncio async def test_login_without_username_should_fail(): """Test whether logging in without passing a username fails""" + clear_db() + await register_user() + response = client.post( "/login", headers={"Content-Type": "application/json"}, - json={"username": username1}, + json={"password": password}, ) assert response.status_code == 422 @@ -129,10 +157,13 @@ async def test_login_without_username_should_fail(): @pytest.mark.asyncio async def test_login_without_password_should_fail(): """Test whether logging in without passing a password fails""" + clear_db() + await register_user() + response = client.post( "/login", headers={"Content-Type": "application/json"}, - json={"username": username1}, + json={"username": username}, ) assert response.status_code == 422 diff --git a/tests/test_courseprogress.py b/tests/test_courseprogress.py index ec41d56..33cf43f 100644 --- a/tests/test_courseprogress.py +++ b/tests/test_courseprogress.py @@ -52,6 +52,7 @@ async def test_register_creates_progress_of_zero(): @pytest.mark.asyncio async def test_get_all_returns_all(): + """Test whether the 'All'-course fetches all course progress values""" clear_db() token = await register_user() @@ -68,6 +69,7 @@ async def test_get_all_returns_all(): @pytest.mark.asyncio async def test_get_nonexisting_course_should_fail(): + """Test whether fetching the progress of a nonexisting course fails""" clear_db() token = await register_user() @@ -81,6 +83,7 @@ async def test_get_nonexisting_course_should_fail(): @pytest.mark.asyncio async def test_patch_course_progress(): + """Test whether patching the progress value of a course works properly""" clear_db() token = await register_user() @@ -91,9 +94,9 @@ async def test_patch_course_progress(): progress_value = random.uniform(0, 1) response = client.patch( - f"/courseprogress", + f"/courseprogress/{course}", headers=headers, - json={"progress_value": progress_value, "course": course}, + json={"progress_value": progress_value}, ) assert response.status_code == 200 @@ -102,6 +105,7 @@ async def test_patch_course_progress(): @pytest.mark.asyncio async def test_patch_all_should_patch_all_courses(): + """Test whether patching the 'All'-course updates all progress values""" clear_db() token = await register_user() @@ -110,9 +114,9 @@ async def test_patch_all_should_patch_all_courses(): progress_value = random.uniform(0, 1) response = client.patch( - f"/courseprogress", + "/courseprogress/All", headers=headers, - json={"progress_value": progress_value, "course": "All"}, + json={"progress_value": progress_value}, ) assert response.status_code == 200 @@ -129,6 +133,7 @@ async def test_patch_all_should_patch_all_courses(): @pytest.mark.asyncio async def test_patch_nonexisting_course_should_fail(): + """Test whether patching a nonexisting course fails""" clear_db() token = await register_user() @@ -139,9 +144,9 @@ async def test_patch_nonexisting_course_should_fail(): progress_value = random.uniform(0, 1) response = client.patch( - f"/courseprogress", + f"/courseprogress/{fake_course}", headers=headers, - json={"progress_value": progress_value, "course": fake_course}, + json={"progress_value": progress_value}, ) assert response.status_code == 422 @@ -149,6 +154,7 @@ async def test_patch_nonexisting_course_should_fail(): @pytest.mark.asyncio async def test_patch_course_with_invalid_value_should_fail(): + """Test whether patching a course progress value with an invalid value fails""" clear_db() token = await register_user() @@ -158,17 +164,17 @@ async def test_patch_course_with_invalid_value_should_fail(): too_low_progress_value = random.uniform(0, 1) - 2 response = client.patch( - f"/courseprogress", + "/courseprogress/All", headers=headers, - json={"progress_value": too_high_progress_value, "course": "All"}, + json={"progress_value": too_high_progress_value}, ) assert response.status_code == 400 response = client.patch( - f"/courseprogress", + "/courseprogress/All", headers=headers, - json={"progress_value": too_low_progress_value, "course": "All"}, + json={"progress_value": too_low_progress_value}, ) assert response.status_code == 400 diff --git a/tests/test_highscores.py b/tests/test_highscores.py index 5e145a4..a5a9a48 100644 --- a/tests/test_highscores.py +++ b/tests/test_highscores.py @@ -15,3 +15,21 @@ client = TestClient(app) username = "user1" password = "password" avatar = "lion" + + +async def register_user(): + response = client.post( + "/register", + headers={"Content-Type": "application/json"}, + json={"username": username, "password": password, "avatar": avatar}, + ) + + assert response.status_code == 200 + + return response.json()["access_token"] + + +@pytest.mark.asyncio +async def test_post_highscore(): + """Test whether posting a new high score succeeds""" + clear_db()