More refactors

This commit is contained in:
lvrossem
2023-04-17 07:51:53 -06:00
parent 0bf764a0f4
commit 38eb9027d6
14 changed files with 142 additions and 89 deletions

View File

@@ -55,7 +55,10 @@ def register(db: Session, username: str, password: str, avatar_index: int):
if db_user:
raise HTTPException(status_code=400, detail="Username already registered")
db_user = User(
username=username, hashed_password=pwd_context.hash(password), avatar_index=avatar_index, playtime=0.0
username=username,
hashed_password=pwd_context.hash(password),
avatar_index=avatar_index,
playtime=0.0,
)
db.add(db_user)
db.commit()

View File

@@ -26,14 +26,10 @@ def get_course_progress(db: Session, user: User, course: CourseEnum):
if course_progress:
result.append(
CourseProgressParent(
progress=course_progress.progress, course=course
)
CourseProgressParent(progress=course_progress.progress, course=course)
)
else:
db.add(
CourseProgress(progress=0.0, course=course, owner_id=user.user_id)
)
db.add(CourseProgress(progress=0.0, course=course, owner_id=user.user_id))
db.commit()
result.append(CourseProgressParent(progress=0.0, course=course))
@@ -44,9 +40,7 @@ def initialize_user(db: Session, user: User):
"""Create CourseProgress records with a value of 0 for a new user"""
for course in CourseEnum:
if course != CourseEnum.All:
db.add(
CourseProgress(progress=0.0, course=course, owner_id=user.user_id)
)
db.add(CourseProgress(progress=0.0, course=course, owner_id=user.user_id))
db.commit()

View File

@@ -1,25 +1,52 @@
import datetime
from fastapi import HTTPException
from sqlalchemy import desc
from sqlalchemy import asc, desc, func
from sqlalchemy.orm import Session
from src.enums import MinigameEnum
from src.models import HighScore, User
from src.schemas.highscores import HighScoreBase
from src.schemas.saved_data import Score
from src.schemas.users import UserHighScore
def get_high_scores(
db: Session, minigame: MinigameEnum, user: User, nr_highest: int, mine_only: bool
def get_most_recent_high_scores(db: Session, minigame: MinigameEnum, amount: int):
"""Get the n most recent high scores of a given minigame"""
if amount < 1:
raise HTTPException(status_code=400, detail="Invalid number of high scores")
high_scores = []
if not minigame:
minigame = MinigameEnum.SpellingBee
high_scores = (
db.query(HighScore)
.filter(HighScore.minigame == minigame)
.order_by(desc(HighScore.score_value))
.limit(amount)
.all()
)
for high_score in high_scores:
high_scores.append(
Score(score_value=high_score.score_value, time=high_score.time)
)
return high_scores
def get_highest_high_scores(
db: Session, minigame: MinigameEnum, user: User, amount: int, mine_only: bool
):
"""Get the n highest scores of a given minigame"""
if nr_highest < 1:
if amount < 1:
raise HTTPException(status_code=400, detail="Invalid number of high scores")
if mine_only:
if nr_highest > 1:
if amount > 1:
raise HTTPException(
status_code=400,
detail="nr_highest should be 1 when requesting high score of current user only",
detail="amount should be 1 when requesting high score of current user only",
)
else:
high_score = (
@@ -31,16 +58,15 @@ def get_high_scores(
)
if high_score:
return [
UserHighScore(
username=user.username,
Score(
score_value=high_score.score_value,
avatar_index=user.avatar_index,
time=high_score.time,
)
]
else:
return []
user_high_scores = []
high_scores = []
if not minigame:
minigame = MinigameEnum.SpellingBee
@@ -48,20 +74,15 @@ def get_high_scores(
high_scores = (
db.query(HighScore)
.filter(HighScore.minigame == minigame)
.order_by(desc(HighScore.score_value))
.limit(nr_highest)
.order_by(asc(HighScore.time))
.limit(amount)
.all()
)
for high_score in high_scores:
owner = db.query(User).filter(User.user_id == high_score.owner_id).first()
user_high_scores.append(
UserHighScore(
username=owner.username,
score_value=high_score.score_value,
avatar_index=owner.avatar_index,
Score(score_value=high_score.score_value, time=high_score.time)
)
)
return user_high_scores
return high_scores
def create_high_score(
@@ -75,6 +96,7 @@ def create_high_score(
score_value=high_score.score_value,
minigame=minigame,
owner_id=user.user_id,
time=str(datetime.datetime.now()),
)
db.add(db_high_score)
db.commit()

33
src/crud/saved_data.py Normal file
View File

@@ -0,0 +1,33 @@
from sqlalchemy.orm import Session
from src.crud.highscores import get_highest_high_scores, get_most_recent_high_scores
from src.crud.users import get_user_by_username
from src.schemas.saved_data import *
from src.enums import MinigameEnum
def get_saved_data(db: Session, username: str):
"""Fetches all saved progress for the current user from the database"""
user = get_user_by_username(db, username)
minigames = []
courses = []
for minigame in MinigameEnum:
minigames.append(
SavedMinigameProgress(
minigame_index=minigame,
latest_scores = get_most_recent_high_scores(db, minigame, 10),
highest_scores = get_highest_high_scores(db, minigame, user, 10, False)
)
)
user_progress = SavedUser(
username=user.username,
avatar_index=user.avatar_index,
playtime=user.playtime,
minigames = minigames,
courses=courses
)
return user_progress

View File

@@ -9,10 +9,11 @@ sys.path.append("..")
from src.crud import authentication as crud_authentication
from src.crud import courseprogress as crud_courseprogress
from src.crud import highscores as crud_highscores
from src.crud import saved_data as crud_saved_data
from src.crud import users as crud_users
from src.database import Base, engine, get_db
from src.enums import CourseEnum, MinigameEnum
from src.schemas import courseprogress, highscores, users
from src.schemas import courseprogress, highscores, users, saved_data
app = FastAPI()
@@ -46,6 +47,13 @@ async def patch_current_user(
crud_users.patch_user(db, current_user_name, user)
@app.get("/saveddata", response_model=saved_data.SavedUser)
async def read_saved_data(
current_user_name: str = Depends(crud_authentication.get_current_user_name),
db: Session = Depends(get_db),
):
return crud_saved_data.get_saved_data(db, current_user_name)
@app.post("/register")
async def register(user: users.UserCreate, db: Session = Depends(get_db)):
access_token = crud_authentication.register(
@@ -61,18 +69,21 @@ async def login(user: users.UserCreate, db: Session = Depends(get_db)):
return crud_authentication.login(db, user.username, user.password)
@app.get("/highscores/{minigame}", response_model=List[users.UserHighScore])
@app.get("/highscores/{minigame}", response_model=List[highscores.Score])
async def get_high_scores(
minigame: MinigameEnum,
nr_highest: Optional[int] = 1,
amount: Optional[int] = 1,
mine_only: Optional[bool] = True,
most_recent: Optional[bool] = False,
current_user_name: str = Depends(crud_authentication.get_current_user_name),
db: Session = Depends(get_db),
):
print(str(nr_highest))
print(str(mine_only))
if most_recent:
return crud_highscores.get_most_recent_high_scores(db, minigame, amount)
user = crud_users.get_user_by_username(db, current_user_name)
return crud_highscores.get_high_scores(db, minigame, user, nr_highest, mine_only)
return crud_highscores.get_highest_high_scores(
db, minigame, user, amount, mine_only
)
@app.put("/highscores/{minigame}", response_model=highscores.HighScore)

View File

@@ -1,4 +1,5 @@
from sqlalchemy import Column, Float, ForeignKey, Integer, String, Boolean
from sqlalchemy import (Boolean, Column, DateTime, Float, ForeignKey, Integer,
String)
from sqlalchemy.orm import relationship
from src.database import Base
@@ -30,7 +31,7 @@ class HighScore(Base):
high_score_id = Column(Integer, primary_key=True, index=True)
score_value = Column(Float, nullable=False)
time = Column(String, nullable=False)
time = Column(DateTime, nullable=False)
minigame = Column(String, nullable=False)
owner_id = Column(Integer, ForeignKey("users.user_id"))
owner = relationship("User", back_populates="high_scores")
@@ -59,5 +60,5 @@ class LearnableProgress(Base):
in_use = Column(Boolean, nullable=False)
name = Column(String, nullable=False)
progress = Column(Float, nullable=False)
course_id = Column(Integer, ForeignKey("course_progress.course_progress_id"))
course_progress_id = Column(Integer, ForeignKey("course_progress.course_progress_id"))
course = relationship("CourseProgress", back_populates="learnables")

View File

@@ -14,3 +14,12 @@ class HighScore(HighScoreBase):
class Config:
orm_mode = True
class Score(BaseModel):
score_id: int
score_value: float
time: str
class Config:
orm_mode = True

View File

@@ -1,14 +1,8 @@
from pydantic import BaseModel
from src.enums import CourseEnum, MinigameEnum
from typing import List
from pydantic import BaseModel
class SavedUser(BaseModel):
username: str
avatar_index: int = -1
playtime: float
minigames: List[SavedMinigameProgress]
courses: List[SavedCourseProgress]
from src.enums import CourseEnum, MinigameEnum
class Score(BaseModel):
@@ -34,6 +28,13 @@ class SavedCourseProgress(BaseModel):
class SavedMinigameProgress(BaseModel):
minigame_index: MinigameEnum
lastest_scores: List[Score]
latest_scores: List[Score]
highest_scores: List[Score]
class SavedUser(BaseModel):
username: str
avatar_index: int = -1
playtime: float
minigames: List[SavedMinigameProgress]
courses: List[SavedCourseProgress]

View File

@@ -15,6 +15,7 @@ username = "user1"
password = "password"
avatar_index = 1
async def register_user():
response = client.post(
"/register",

View File

@@ -1,9 +1,7 @@
import pytest
from fastapi.testclient import TestClient
from src.main import app, get_db
from tests.base import avatar_index, client, password, username, register_user
from tests.config.database import clear_db, override_get_db
from tests.base import avatar_index, client, password, register_user, username
from tests.config.database import clear_db
@pytest.mark.asyncio

View File

@@ -1,12 +1,10 @@
import random
import pytest
from fastapi.testclient import TestClient
from src.enums import CourseEnum
from src.main import app, get_db
from tests.base import client, register_user
from tests.config.database import clear_db, override_get_db
from tests.config.database import clear_db
@pytest.mark.asyncio

View File

@@ -1,12 +1,10 @@
import random
import pytest
from fastapi.testclient import TestClient
from src.enums import MinigameEnum
from src.main import app, get_db
from tests.base import avatar_index, client, password, username, register_user
from tests.config.database import clear_db, override_get_db
from tests.base import client, password, register_user
from tests.config.database import clear_db
@pytest.mark.asyncio

View File

@@ -1,13 +1,11 @@
import pytest
from fastapi.testclient import TestClient
from src.main import app, get_db
from tests.base import avatar_index, client, password, username, register_user
from tests.config.database import clear_db, override_get_db
from tests.base import avatar_index, client, register_user, username
from tests.config.database import clear_db
patched_username = "New name"
patched_password = "New password"
patched_avatar = "New avatar"
patched_avatar_index = 2
@pytest.mark.asyncio
@@ -22,7 +20,7 @@ async def test_get_current_user():
assert response.status_code == 200
response = response.json()
assert response["username"] == username
assert response["avatar"] == avatar
assert response["avatar_index"] == avatar_index
@pytest.mark.asyncio
@@ -40,14 +38,7 @@ async def test_patch_user():
"""Test the patching of a user's username, password and avatar"""
clear_db()
response = client.post(
"/register",
headers={"Content-Type": "application/json"},
json={"username": username, "password": password, "avatar": avatar},
)
assert response.status_code == 200
token = response.json()["access_token"]
token = await register_user()
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
@@ -56,7 +47,7 @@ async def test_patch_user():
json={
"username": patched_username,
"password": patched_password,
"avatar": patched_avatar,
"avatar_index": patched_avatar_index,
},
headers=headers,
)
@@ -77,7 +68,7 @@ async def test_patch_user():
assert response.status_code == 200
# Correctness of password and username is already asserted by the login
assert response.json()["avatar"] == patched_avatar
assert response.json()["avatar_index"] == patched_avatar_index
@pytest.mark.asyncio
@@ -85,14 +76,7 @@ async def test_patch_user_with_empty_fields():
"""Patching a user with empty fields should fail"""
clear_db()
response = client.post(
"/register",
headers={"Content-Type": "application/json"},
json={"username": username, "password": password, "avatar": avatar},
)
assert response.status_code == 200
token = response.json()["access_token"]
token = await register_user()
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
@@ -101,18 +85,18 @@ async def test_patch_user_with_empty_fields():
json={
"username": patched_username,
"password": patched_password,
"avatar": "",
"avatar_index": "",
},
headers=headers,
)
assert response.status_code == 400
assert response.status_code == 422
response = client.patch(
"/users",
json={
"username": patched_username,
"password": "",
"avatar": patched_avatar,
"avatar_index": patched_avatar_index,
},
headers=headers,
)
@@ -123,7 +107,7 @@ async def test_patch_user_with_empty_fields():
json={
"username": "",
"password": patched_password,
"avatar": patched_avatar,
"avatar_index": patched_avatar_index,
},
headers=headers,
)